In [1]:
import torch
import torch.nn as nn

In [2]:
class Block(nn.Module):
    """Discriminator stage: 4x4 conv -> InstanceNorm2d -> LeakyReLU(0.2).

    The convolution pads by 1 using reflect padding. With ``stride=2`` an
    even spatial size is halved; with ``stride=1`` it shrinks by one.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: stride of the 4x4 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size = 4,
            stride = stride,
            padding = 1,
            padding_mode = 'reflect',
        )
        norm = nn.InstanceNorm2d(out_channels)
        act = nn.LeakyReLU(0.2, inplace = True)
        # Kept as a Sequential named ``layer`` so state_dict keys stay ``layer.0.*``.
        self.layer = nn.Sequential(conv, norm, act)

    def forward(self, x):
        out = self.layer(x)
        return out

In [5]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a spatial map of sigmoid scores.

    Layout:
      * ``initial``: 4x4 stride-2 conv (default zero padding, no norm) + LeakyReLU(0.2).
      * ``backbone``: one ``Block`` per consecutive pair in ``out_channels``.
        Each uses stride 2, except the last pair, which uses stride 1.
        A final 4x4 stride-1 reflect-padded conv then maps to 1 channel.

    With the default channels, a 1x3x256x256 input gives a 1x1x30x30 output:
    256 -> 128 -> 64 -> 32 -> 31 -> 30.

    Args:
        in_channels: number of channels in the input image.
        out_channels: channel widths of the successive stages. Any sequence of
            ints is accepted. The default is an immutable tuple, so no shared
            mutable default object exists.
    """

    def __init__(self, in_channels = 3, out_channels = (64, 128, 256, 512)):
        super().__init__()
        # Normalize to a tuple: accepts lists/tuples and never keeps a caller's mutable object.
        out_channels = tuple(out_channels)
        last = len(out_channels) - 1

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, out_channels[0], kernel_size = 4, stride = 2, padding = 1),
            nn.LeakyReLU(0.2, inplace = True)
        )
        # One Block per consecutive channel pair; only the final pair uses stride 1.
        layers = [
            Block(out_channels[i], out_channels[i + 1], stride = 1 if i == last - 1 else 2)
            for i in range(last)
        ]
        # Project to a single-channel score map.
        layers.append(nn.Conv2d(out_channels[-1], 1, kernel_size = 4, stride = 1, padding = 1, padding_mode = 'reflect'))
        self.backbone = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        x = self.backbone(x)
        # Squash to (0, 1) so the scores are valid inputs for nn.BCELoss.
        return torch.sigmoid(x)

In [6]:
D_model = Discriminator()
test = torch.rand(1, 3, 256, 256)
# Shape check only: no_grad avoids recording an autograd graph we never backprop through.
with torch.no_grad():
    d_out = D_model(test)
# 256 -> 128 -> 64 -> 32 (stride-2 stages) -> 31 -> 30 (two stride-1 4x4 convs, padding 1)
assert d_out.shape == (1, 1, 30, 30), d_out.shape
print(d_out.shape)

torch.Size([1, 1, 30, 30])


In [15]:
class ConvBlock(nn.Module):
    """Conv (or transposed conv) -> InstanceNorm2d -> optional ReLU.

    Args:
        inp: input channel count.
        out: output channel count.
        k: kernel size.
        s: stride.
        p: padding.
        down: if True, use ``nn.Conv2d`` with ``padding_mode='reflect'``.
            Otherwise use ``nn.ConvTranspose2d`` (no ``padding_mode`` argument).
        act: if True, finish with an in-place ReLU; otherwise with ``nn.Identity``.
    """

    def __init__(self, inp, out, k, s, p, down = True, act = True):
        super().__init__()
        if down:
            conv = nn.Conv2d(inp, out, kernel_size = k, stride = s, padding = p, padding_mode = 'reflect')
        else:
            conv = nn.ConvTranspose2d(inp, out, kernel_size = k, stride = s, padding = p)
        norm = nn.InstanceNorm2d(out)
        if act:
            activation = nn.ReLU(inplace = True)
        else:
            activation = nn.Identity()
        self.layer = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.layer(x)
        return out

In [29]:
class ResBlock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is two 3x3 stride-1 padding-1 ``ConvBlock``s at constant width
    ``features``. The first ends in ReLU; the second has no activation.
    The spatial size and channel count are therefore preserved.
    """

    def __init__(self, features):
        super().__init__()
        self.layer = nn.Sequential(
            ConvBlock(features, features, k = 3, s = 1, p = 1, down = True, act = True),
            ConvBlock(features, features, k = 3, s = 1, p = 1, down = True, act = False),
        )

    def forward(self, x):
        residual = self.layer(x)
        return residual + x

In [30]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline, with C = ``num_features``:
      * ``initial``: 7x7 reflect-padded conv to C channels + InstanceNorm2d + ReLU.
      * ``down``: two stride-2 ``ConvBlock``s, C -> 2C -> 4C.
      * ``res_blocks``: ``num_residuals`` ``ResBlock``s at 4C.
      * ``up``: two stride-2 transposed-conv ``ConvBlock``s, 4C -> 2C -> C.
      * ``last``: 7x7 reflect-padded conv back to ``img_channels``, then tanh.

    Outputs lie in [-1, 1]. For inputs whose height and width are divisible
    by 4, the spatial size is preserved.
    """

    def __init__(self, img_channels = 3, num_features = 64, num_residuals = 9):
        super().__init__()
        c1, c2, c4 = num_features, num_features * 2, num_features * 4

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, c1, kernel_size = 7, stride = 1, padding = 3, padding_mode = 'reflect'),
            nn.InstanceNorm2d(c1),
            nn.ReLU(inplace = True),
        )

        self.down = nn.Sequential(
            ConvBlock(c1, c2, 3, 2, 1),
            ConvBlock(c2, c4, 3, 2, 1),
        )

        self.res_blocks = nn.Sequential(*(ResBlock(c4) for _ in range(num_residuals)))

        self.up = nn.Sequential(
            ConvBlock(c4, c2, 4, 2, 1, down = False),
            ConvBlock(c2, c1, 4, 2, 1, down = False),
        )

        # One-element Sequential kept so state_dict keys stay ``last.0.*``.
        self.last = nn.Sequential(
            nn.Conv2d(c1, img_channels, kernel_size = 7, stride = 1, padding = 3, padding_mode = 'reflect'),
        )

    def forward(self, x):
        for stage in (self.initial, self.down, self.res_blocks, self.up, self.last):
            x = stage(x)
        return torch.tanh(x)


In [31]:
G_model = Generator()
test = torch.rand(1, 3, 256, 256)
# Shape check only: no_grad avoids recording an autograd graph we never backprop through.
with torch.no_grad():
    g_out = G_model(test)
# Two stride-2 downsamples are undone by two stride-2 transposed convs, so the shape is preserved.
assert g_out.shape == test.shape, g_out.shape
print(g_out.shape)

torch.Size([1, 3, 256, 256])


In [None]:
# Use the GPU when CUDA is available, otherwise fall back to CPU.
# (torch.device has no is_available(); the CUDA check lives in torch.cuda.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class Model():
    """CycleGAN-style trainer for two unpaired image domains ("black" and "gray").

    * ``Gen_A``: black -> gray generator. ``Disc_A`` scores gray images.
    * ``Gen_B``: gray -> black generator. ``Disc_B`` scores black images.

    All four networks are moved to the module-level ``device``, and each has
    its own Adam optimizer.

    Args:
        lr: Adam learning rate for all four optimizers. The default, 1e-4,
            is the value previously hard-coded.
        lambda_cycle: weight of the L1 cycle-consistency terms. The default,
            10.0, is the value previously hard-coded in ``train``.
    """

    def __init__(self, lr = 0.0001, lambda_cycle = 10.0):
        self.device = device
        self.lambda_cycle = lambda_cycle

        self.Gen_A = Generator().to(self.device)
        self.Gen_B = Generator().to(self.device)
        self.Disc_A = Discriminator().to(self.device)
        self.Disc_B = Discriminator().to(self.device)

        self.Gen_A_opt = torch.optim.Adam(self.Gen_A.parameters(), lr = lr)
        self.Gen_B_opt = torch.optim.Adam(self.Gen_B.parameters(), lr = lr)
        self.Disc_A_opt = torch.optim.Adam(self.Disc_A.parameters(), lr = lr)
        self.Disc_B_opt = torch.optim.Adam(self.Disc_B.parameters(), lr = lr)

        # The discriminators end in a sigmoid, so plain BCELoss is the matching criterion.
        self.criterion = nn.BCELoss()
        self.l1 = nn.L1Loss()

    @staticmethod
    def _set_requires_grad(nets, flag):
        """Set ``requires_grad`` to ``flag`` on every parameter of each module in ``nets``."""
        for net in nets:
            for param in net.parameters():
                param.requires_grad_(flag)

    def _gan_loss(self, pred, target_is_real):
        """BCE between ``pred`` and an all-ones (real) or all-zeros (fake) target of the same shape."""
        target = torch.ones_like(pred) if target_is_real else torch.zeros_like(pred)
        return self.criterion(pred, target)

    def _discriminator_step(self, disc, opt, real, fake):
        """Take one optimizer step for ``disc`` on a real batch and a detached fake batch.

        Returns the loss tensor, 0.5 * (BCE on real + BCE on fake).
        """
        real_pred = disc(real)
        # detach: this update must not backpropagate into the generator that produced ``fake``.
        fake_pred = disc(fake.detach())
        loss = 0.5 * (self._gan_loss(real_pred, True) + self._gan_loss(fake_pred, False))
        opt.zero_grad()
        loss.backward()
        opt.step()
        return loss

    def train(self, real_black, real_gray):
        """Run one training iteration on one batch from each domain.

        The discriminators are updated first, on detached fakes. Both
        generators are then updated jointly on
        ``adversarial + lambda_cycle * cycle`` for each direction.

        Args:
            real_black: batch from the "black" domain, the input to ``Gen_A``.
                Presumably (N, 3, H, W) in the generator's tanh range -- TODO confirm.
            real_gray: batch from the "gray" domain, the input to ``Gen_B``.

        Returns:
            dict of Python floats with keys ``disc_gray_loss``, ``disc_black_loss``,
            ``gen_gray_loss`` and ``gen_black_loss``.
        """
        # No-op when the batches already live on the model's device.
        real_black = real_black.to(self.device)
        real_gray = real_gray.to(self.device)

        fake_gray = self.Gen_A(real_black)
        fake_black = self.Gen_B(real_gray)

        # -------------------------
        # Discriminators A (gray) and B (black)
        # -------------------------
        disc_gray_loss = self._discriminator_step(self.Disc_A, self.Disc_A_opt, real_gray, fake_gray)
        disc_black_loss = self._discriminator_step(self.Disc_B, self.Disc_B_opt, real_black, fake_black)

        # -------------------------
        # Generators A (Black->Gray) and B (Gray->Black), updated jointly
        # -------------------------
        # Freeze the discriminators so the generator backward pass neither
        # computes nor accumulates gradients in their parameters.
        discriminators = (self.Disc_A, self.Disc_B)
        self._set_requires_grad(discriminators, False)
        try:
            # Evaluate each discriminator once and reuse the prediction for the target shape.
            gray_adversarial_loss = self._gan_loss(self.Disc_A(fake_gray), True)
            black_adversarial_loss = self._gan_loss(self.Disc_B(fake_black), True)

            # Cycle terms: black -> gray -> black and gray -> black -> gray.
            gray_cycle_loss = self.l1(self.Gen_B(fake_gray), real_black)
            black_cycle_loss = self.l1(self.Gen_A(fake_black), real_gray)

            gen_gray_loss = gray_adversarial_loss + self.lambda_cycle * gray_cycle_loss
            gen_black_loss = black_adversarial_loss + self.lambda_cycle * black_cycle_loss

            # A single backward over the summed objective lets each cycle term
            # update BOTH generators it passes through. If each generator did its
            # own zero_grad/backward/step, the gradient reaching the reconstructing
            # generator would be zeroed by the other optimizer and never applied.
            self.Gen_A_opt.zero_grad()
            self.Gen_B_opt.zero_grad()
            (gen_gray_loss + gen_black_loss).backward()
            self.Gen_A_opt.step()
            self.Gen_B_opt.step()
        finally:
            self._set_requires_grad(discriminators, True)

        return {
            "disc_gray_loss": disc_gray_loss.item(),
            "disc_black_loss": disc_black_loss.item(),
            "gen_gray_loss": gen_gray_loss.item(),
            "gen_black_loss": gen_black_loss.item()
        }