# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
from tqdm import tqdm
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from LookGenerator.datasets.refinement_dataset import RefinementDiscriminatorDataset
from LookGenerator.networks.refinement import RefinementDiscriminator
from LookGenerator.networks.trainer import Trainer
from LookGenerator.networks_training.utils import check_path_and_creat
from LookGenerator.networks.losses import WassersteinLoss, GradientPenalty
import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [2]:
# Preprocessing applied to every image: resize to 256x192 (height x width),
# then apply the project's MinMaxScale transform.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 192)),
        custom_transforms.MinMaxScale(),
    ]
)

In [3]:
# DataLoader settings shared by the train/validation loaders below.
batch_size_train, batch_size_val = 32, 16  # per-split batch sizes
pin_memory = True  # copy batches into pinned (page-locked) host memory
num_workers = 0  # 0 = load batches in the main process

In [4]:
# Dataset roots for the "fake" and "real" image folders. The defaults are the
# original author's local folders; set the environment variables to run this
# notebook on another machine without editing the cell.
FAKE_IMAGES_ROOT = os.environ.get(
    "REFINEMENT_FAKE_IMAGES_ROOT", r"C:\Users\DenisovDmitrii\Desktop\forRefinement"
)
REAL_IMAGES_ROOT = os.environ.get(
    "REFINEMENT_REAL_IMAGES_ROOT", r"C:\Users\DenisovDmitrii\Desktop\forEncoder"
)

train_dataset = RefinementDiscriminatorDataset(
    fake_images_dir=os.path.join(FAKE_IMAGES_ROOT, "train"),
    real_images_dir=os.path.join(REAL_IMAGES_ROOT, "train", "image"),
    transform=transform,
)

val_dataset = RefinementDiscriminatorDataset(
    fake_images_dir=os.path.join(FAKE_IMAGES_ROOT, "val"),
    real_images_dir=os.path.join(REAL_IMAGES_ROOT, "val", "image"),
    transform=transform,
)

In [5]:
# Training batches are reshuffled every epoch; validation order is kept fixed so
# evaluation passes see the samples in a deterministic order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size_val,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Обучение модели

In [6]:
# Pick the device first so the model's parameters are moved there *before* the
# optimizer is built (torch.optim docs recommend moving a model to its device
# before constructing optimizers for it).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = RefinementDiscriminator().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

criterion = WassersteinLoss()
gp = GradientPenalty(model, device=device)

print(device)

cuda


In [7]:
# Output folder for this session's weights. The default is the original author's
# OneDrive folder; set REFINEMENT_DISCRIMINATOR_SAVE_DIR to save elsewhere.
save_directory = os.environ.get(
    "REFINEMENT_DISCRIMINATOR_SAVE_DIR",
    r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\discriminator_pretrained\session1",
)
# Project helper; its name suggests it creates the folder when missing (TODO confirm).
check_path_and_creat(save_directory)

True

In [8]:
# Let cuDNN benchmark convolution algorithms and reuse the fastest one; inputs
# are resized to a fixed size by the transform, so the benchmark is reused.
torch.backends.cudnn.benchmark = True
# Keep full FP32 precision for CUDA matrix multiplications (TF32 disabled).
torch.backends.cuda.matmul.allow_tf32 = False

In [9]:
from torchsummary import summary

# Layer-by-layer summary of the feature extractor for one 3x256x192 input.
# `device` must be passed explicitly: torchsummary defaults to "cuda", which
# crashes on a CPU-only kernel even though `device` was chosen dynamically.
model.to(device)
summary(model.features, (3, 256, 192), device=device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 256, 192]             224
              ReLU-2          [-1, 8, 256, 192]               0
              ReLU-3          [-1, 8, 256, 192]               0
           Conv3x3-4          [-1, 8, 256, 192]               0
            Conv2d-5          [-1, 8, 256, 192]             584
           Conv3x3-6          [-1, 8, 256, 192]               0
              ReLU-7          [-1, 8, 256, 192]               0
              ReLU-8          [-1, 8, 256, 192]               0
           Conv5x5-9          [-1, 8, 256, 192]               0
           Conv2d-10           [-1, 8, 128, 96]           1,024
           Conv2d-11          [-1, 16, 128, 96]           1,168
             ReLU-12          [-1, 16, 128, 96]               0
             ReLU-13          [-1, 16, 128, 96]               0
          Conv3x3-14          [-1, 16, 

In [None]:
def train(model, optimizer, criterion, train_dl, device, epochs):
    """Train a discriminator on batches of labelled real/fake images.

    Args:
        model: discriminator ``nn.Module``; it is moved to ``device`` and switched
            to training mode.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion(preds, labels)`` returning a scalar loss tensor.
        train_dl: iterable of ``(images, labels)`` batches, e.g. a ``DataLoader``.
            A label equal to 1 marks a real sample; any other value is counted as
            fake (this mirrors the original ``labels == 1`` check; confirm the
            label encoding against ``RefinementDiscriminatorDataset``).
        device: device the model and every batch are moved to.
        epochs: number of passes over ``train_dl``.

    Returns:
        ``(losses, real_scores, fake_scores)``, with one entry per epoch: the mean
        batch loss, the mean model output on real samples, and the mean output on
        fake samples. An epoch without any real (or fake) samples records ``nan``.
    """

    def epoch_mean(values):
        # np.mean([]) warns and returns nan; make the empty case explicit.
        return float(np.mean(values)) if values else float("nan")

    model.to(device)
    model.train()
    torch.cuda.empty_cache()

    losses, real_scores, fake_scores = [], [], []

    for epoch in range(epochs):
        loss_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []

        # Each batch is an (images, labels) pair. The original unpacked
        # enumerate()'s 2-tuples into three names, which raised ValueError.
        for images, labels in tqdm(train_dl, desc=f"epoch {epoch + 1}/{epochs}"):
            optimizer.zero_grad()

            images = images.to(device)
            # Labels must live on the same device as the predictions for the loss.
            labels = labels.to(device)

            preds = model(images)
            loss = criterion(preds, labels)

            # A batch may mix real and fake samples, so score each sample instead
            # of branching on the whole label tensor (``if tensor == ...`` fails
            # for batches larger than one). Assumes dim 0 of ``preds`` is the
            # batch dimension; TODO confirm for RefinementDiscriminator.
            sample_scores = preds.detach().reshape(preds.shape[0], -1).mean(dim=1)
            is_real = labels.reshape(-1) == 1
            if is_real.any():
                real_score_per_epoch.append(sample_scores[is_real].mean().item())
            if (~is_real).any():
                fake_score_per_epoch.append(sample_scores[~is_real].mean().item())

            # TODO: no gradient-penalty term is added to ``loss`` yet.
            loss.backward()
            optimizer.step()
            loss_per_epoch.append(loss.item())

        losses.append(epoch_mean(loss_per_epoch))
        real_scores.append(epoch_mean(real_score_per_epoch))
        fake_scores.append(epoch_mean(fake_score_per_epoch))

    return losses, real_scores, fake_scores