# Constants that need to be set

In [1]:
MODEL_SAVE_PATH = '/content/drive/MyDrive/pwr/sztuczna'
DATASET_PATH = '/content/drive/MyDrive/pwr/sztuczna/train'
TENSORBOARD_LOG_PATH = '/content/drive/MyDrive/pwr/sztuczna/logs'
EPOCHS = 30
BATCH_SIZE = 16

# Imports

In [2]:
import os
import torch
import sys
import numpy as np

from torch import nn, optim
from PIL import Image
from skimage.color import rgb2lab
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter

# Define classes (model & dataset)

In [3]:
class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        # First block coding
        self.b1_conv2d = nn.Conv2d(1, 25, kernel_size=4, stride=2, padding=1, bias=False)

        # Second block coding
        self.b2_leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.b2_conv2d = nn.Conv2d(25, 50, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b2_batch_norm2d_1 = nn.BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Third block coding
        self.b3_leaky_relu =  nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.b3_conv2d = nn.Conv2d(50, 100, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b3_batch_norm2d_1 = nn.BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Fourth block coding
        self.b4_leaky_relu =  nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.b4_conv2d = nn.Conv2d(100, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b4_batch_norm2d_1 = nn.BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Middle block
        self.mb_leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.mb_conv2d = nn.Conv2d(200, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.mb_relu = nn.ReLU(inplace=True)
        self.mb_conv_transpose2d = nn.ConvTranspose2d(200, 200, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.mb_batch_norm2d = nn.BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Fourth block decoding
        self.b4_relu = nn.ReLU(inplace=True)
        self.b4_conv_transpose2d = nn.ConvTranspose2d(400, 100, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b4_batch_norm2d_2 = nn.BatchNorm2d(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Third block decoding
        self.b3_relu = nn.ReLU(inplace=True)
        self.b3_conv_transpose2d = nn.ConvTranspose2d(200, 50, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b3_batch_norm2d_2 = nn.BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Second block decoding
        self.b2_relu = nn.ReLU(inplace=True)
        self.b2_conv_transpose2d = nn.ConvTranspose2d(100, 25, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.b2_batch_norm2d_2 = nn.BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        # Third block decoding
        self.b1_relu = nn.ReLU(inplace=True)
        self.b1_conv_transpose2d = nn.ConvTranspose2d(50, 2, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        self.b1_batch_norm2d = nn.BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

        def init_weights(layer):
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

        # Init weights
        self.apply(init_weights)
    
    def forward(self, x):
        # First block coding
        x = self.b1_conv2d(x)
        b1_connection = x

        # Second block coding
        x = self.b2_leaky_relu(x)
        x = self.b2_conv2d(x)
        x = self.b2_batch_norm2d_1(x)
        b2_connection = x

        # Third block coding
        x = self.b3_leaky_relu(x)
        x = self.b3_conv2d(x)
        x = self.b3_batch_norm2d_1(x)
        b3_connection = x

        # Fourth block coding
        x = self.b4_leaky_relu(x)
        x = self.b4_conv2d(x)
        x = self.b4_batch_norm2d_1(x)
        b4_connection = x

        # Middle block
        x = self.mb_leaky_relu(x)
        x = self.mb_conv2d(x)
        x = self.mb_relu(x)
        x = self.mb_conv_transpose2d(x)
        x = self.mb_batch_norm2d(x)

        # Fourth block decoding
        x = self.b4_relu(torch.cat([b4_connection, x], 1))
        x = self.b4_conv_transpose2d(x)
        x = self.b4_batch_norm2d_2(x)

        # Third block decoding
        x = self.b3_relu(torch.cat([b3_connection, x], 1))
        x = self.b3_conv_transpose2d(x)
        x = self.b3_batch_norm2d_2(x)

        # Second block decoding
        x = self.b2_relu(torch.cat([b2_connection, x], 1))
        x = self.b2_conv_transpose2d(x)
        x = self.b2_batch_norm2d_2(x)

        # First block decoding
        x = self.b1_relu(torch.cat([b1_connection, x], 1))
        x = self.b1_conv_transpose2d(x)
        x = self.b1_batch_norm2d(x)

        return x


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
          nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=True),
          nn.LeakyReLU(0.2, True),
          nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm2d(128),
          nn.LeakyReLU(0.2, True),
          nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
          nn.BatchNorm2d(256),
          nn.LeakyReLU(0.2, True),
          nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=False),
          nn.BatchNorm2d(512),
          nn.LeakyReLU(0.2, True),
          nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, bias=True)
        )
        
        # Initialize weights
        for layer in self.model:
            if isinstance(layer, nn.Conv2d):
                nn.init.xavier_uniform_(layer.weight)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
                

    def forward(self, x):
        return self.model(x)
        
class ColorizationDataset(Dataset):
    def __init__(self):
        self.paths = [os.path.join(DATASET_PATH, file) for file in os.listdir(DATASET_PATH)]
    
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}
    
    def __len__(self):
        return len(self.paths)

# Init traning

In [4]:
writer = SummaryWriter(TENSORBOARD_LOG_PATH)

dataset = ColorizationDataset()
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = Unet().to(device)
discriminator = Discriminator().to(device)
discriminator_true_output = torch.tensor(1.0).to(device)
discriminator_false_output = torch.tensor(0.0).to(device)

loss_function = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

generator_optimizer = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))



# Train

In [5]:
for e in range(EPOCHS):
    i = 0
    generator_loss_sum = 0
    generator_loss_count = 0
    discriminator_loss_sum = 0
    discriminator_loss_count = 0
    for data in data_loader:
        # Fetch image channels from data
        L = data['L'].to(device)
        ab = data['ab'].to(device)

        # Train discriminator
        fake_color = generator(L)
        discriminator.train()
        for param in discriminator.parameters():
            param.requires_grad = True
        discriminator_optimizer.zero_grad()
        fake_image = torch.cat([L, fake_color], dim=1)
        fake_preds = discriminator(fake_image.detach())
        discriminator_loss_fake = loss_function(fake_preds, discriminator_true_output.expand_as(fake_preds))
        real_image = torch.cat([L, ab], dim=1)
        real_preds = discriminator(real_image)
        discriminator_loss_real = loss_function(real_preds, discriminator_false_output.expand_as(real_preds))
        discriminator_loss = (discriminator_loss_fake + discriminator_loss_real) * 0.5
        discriminator_loss.backward()
        discriminator_loss_sum += discriminator_loss.item() * L.size(0)
        discriminator_loss_count += L.size(0)
        discriminator_optimizer.step()

        # Train Generator
        generator.train()
        for param in discriminator.parameters():
            param.requires_grad = False
        generator_optimizer.zero_grad()
        fake_image = torch.cat([L, fake_color], dim=1)
        fake_preds = discriminator(fake_image)
        generator_loss = loss_function(fake_preds, discriminator_false_output.expand_as(fake_preds))
        generator_l1_loss = l1_loss(fake_color, ab) * 100
        generator_loss = generator_loss + generator_l1_loss
        generator_loss.backward()
        generator_loss_sum += generator_loss.item() * L.size(0)
        generator_loss_count += L.size(0)
        generator_optimizer.step()
        
        # Print train state every 32 steps
        i += 1
        if i % 32 != 0:
          continue;
        print(f"\nEpoch {e+1}/{EPOCHS}")
        print(f"Iteration {i}/{len(data_loader)}")
        # Log current loss and epoch avg loss
        print(f"Loss gen {generator_loss.item():.4f} ({(generator_loss_sum / generator_loss_count):.4f})")
        print(f"Loss dis {discriminator_loss.item():.4f} ({(discriminator_loss_sum / discriminator_loss_count):.4f})")

    # Log losses
    writer.add_scalar("Loss/generator", (generator_loss_sum / generator_loss_count), e)
    writer.add_scalar("Loss/discriminator", (discriminator_loss_sum / discriminator_loss_count), e)
    writer.flush()

[1;30;43mStrumieniowane dane wyjściowe obcięte do 5000 ostatnich wierszy.[0m

Epoch 20/30
Iteration 1472/3040
Loss gen 11.7354 (14.1853)
Loss dis 0.4102 (0.1716)

Epoch 20/30
Iteration 1504/3040
Loss gen 11.5069 (14.1832)
Loss dis 0.1444 (0.1706)

Epoch 20/30
Iteration 1536/3040
Loss gen 14.4699 (14.1855)
Loss dis 0.0041 (0.1695)

Epoch 20/30
Iteration 1568/3040
Loss gen 12.7917 (14.1849)
Loss dis 0.2393 (0.1677)

Epoch 20/30
Iteration 1600/3040
Loss gen 13.3730 (14.1776)
Loss dis 0.0223 (0.1669)

Epoch 20/30
Iteration 1632/3040
Loss gen 14.6770 (14.1909)
Loss dis 0.0223 (0.1676)

Epoch 20/30
Iteration 1664/3040
Loss gen 12.7604 (14.1810)
Loss dis 0.0994 (0.1688)

Epoch 20/30
Iteration 1696/3040
Loss gen 10.7518 (14.1903)
Loss dis 0.2373 (0.1676)

Epoch 20/30
Iteration 1728/3040
Loss gen 14.5244 (14.1875)
Loss dis 0.2682 (0.1681)

Epoch 20/30
Iteration 1760/3040
Loss gen 16.6604 (14.1819)
Loss dis 0.0187 (0.1703)

Epoch 20/30
Iteration 1792/3040
Loss gen 10.2937 (14.1690)
Loss dis 0.

# Save model

In [6]:
torch.save(generator.state_dict(), MODEL_SAVE_PATH + '/model.pt')