In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
#TRAIN_DIR = "data/train"
#VAL_DIR = "data/val"
# NOTE(review): BATCH_SIZE, NUM_WORKERS and NUM_EPOCHS are not read by the
# visible main(), which hard-codes batch_size=1, num_workers=0 and a single
# epoch -- confirm whether these constants are meant to drive it.
BATCH_SIZE = 1
# Learning rate handed to both Adam optimizers and re-applied to the optimizer
# after a checkpoint is loaded.
LEARNING_RATE = 1e-6
# Weights of the identity and cycle-consistency L1 terms in the generator loss.
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
NUM_WORKERS = 4
NUM_EPOCHS = 50
# main() restores the four checkpoints below before training when LOAD_MODEL is
# true, and rewrites them after each epoch when SAVE_MODEL is true.
LOAD_MODEL = True
SAVE_MODEL = True
# Checkpoint files for the two generators and two discriminators. The "_C"
# models are stored under the "genz" / "criticz" file names.
CHECKPOINT_GEN_H = "D:\\Personal Research Dataset\\GAN\\genh.pth.tar"
CHECKPOINT_GEN_C = "D:\\Personal Research Dataset\\GAN\\genz.pth.tar"
CHECKPOINT_CRITIC_H = "D:\\Personal Research Dataset\\GAN\\critich.pth.tar"
CHECKPOINT_CRITIC_C = "D:\\Personal Research Dataset\\GAN\\criticz.pth.tar"

In [None]:
!pip install torchvision
!pip install albumentations

In [None]:
import torch

# Sanity check: as the last expression of the cell, this displays whether
# PyTorch reports a usable CUDA device.
torch.cuda.is_available()



In [None]:
import tensorflow as tf

In [None]:
# Preprocessing/augmentation pipeline shared by both image domains.
# Dset calls it as transform(image=..., image0=...); registering "image0" as an
# additional image target lets both inputs go through the same pipeline call.
transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        # NOTE(review): p=1.0 flips every sample vertically (unlike the 0.5
        # used for the horizontal flip) -- confirm this is intentional.
        A.VerticalFlip(p=1.0),
        # With mean=std=0.5 and max_pixel_value=255, pixel values in [0, 255]
        # map to [-1, 1]; train_fn undoes this with `* 0.5 + 0.5` before saving.
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np

class Dset(Dataset):
    """Unpaired two-domain image dataset (Common Scab / Healthy).

    Each item is a ``(Common_Scab_img, Healthy_img)`` pair. The dataset length
    is the size of the larger folder; the smaller folder wraps around via
    modulo indexing, so its images are reused.

    Args:
        root_Common_Scab: Directory holding the Common Scab images.
        root_Healthy: Directory holding the Healthy images.
        transform: Optional callable invoked as
            ``transform(image=..., image0=...)`` that returns a mapping with
            the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If either directory contains no entries.
    """

    def __init__(self, root_Common_Scab, root_Healthy, transform=None):
        self.root_Common_Scab = root_Common_Scab
        self.root_Healthy = root_Healthy
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible across runs and machines.
        self.Common_Scab_images = sorted(os.listdir(root_Common_Scab))
        self.Healthy_images = sorted(os.listdir(root_Healthy))
        self.Common_Scab_len = len(self.Common_Scab_images)
        self.Healthy_len = len(self.Healthy_images)

        # Fail early with a clear message instead of a ZeroDivisionError from
        # the modulo in __getitem__.
        if self.Common_Scab_len == 0 or self.Healthy_len == 0:
            raise ValueError(
                "Both image directories must contain at least one file "
                f"(Common_Scab: {self.Common_Scab_len}, Healthy: {self.Healthy_len})"
            )

        self.length_dataset = max(self.Common_Scab_len, self.Healthy_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB array and close the file handle right away."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        # Wrap the index independently per domain so the shorter list is reused.
        Common_Scab_name = self.Common_Scab_images[index % self.Common_Scab_len]
        Healthy_name = self.Healthy_images[index % self.Healthy_len]

        Common_Scab_img = self._load_rgb(
            os.path.join(self.root_Common_Scab, Common_Scab_name)
        )
        Healthy_img = self._load_rgb(os.path.join(self.root_Healthy, Healthy_name))

        if self.transform:
            augmentations = self.transform(image=Common_Scab_img, image0=Healthy_img)
            Common_Scab_img = augmentations["image"]
            Healthy_img = augmentations["image0"]

        return Common_Scab_img, Healthy_img

In [None]:

import torch.nn as nn


class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        normalisation = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        # The attribute name and layer order determine the state_dict keys
        # ("conv.0.weight", ...) that saved checkpoints rely on.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        features = self.conv(x)
        return features


class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a 1-channel score map in (0, 1).

    The network is an initial stride-2 convolution (no normalisation), one
    ``Block`` per remaining entry of ``features`` (stride 2, except the last
    entry which uses stride 1), and a final stride-1 convolution down to a
    single channel. The result is passed through a sigmoid.

    Args:
        in_channels: Number of channels of the input image.
        features: Channel widths of the successive stages. Any sequence is
            accepted; the default is an immutable tuple so it cannot be
            shared and mutated between instances.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_position = len(features) - 1
        for position, feature in enumerate(features[1:], start=1):
            # Choose the stride-1 stage by position, not by comparing widths:
            # a width that merely repeats the final value must still downsample.
            stride = 1 if position == last_position else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,
                1,
                kernel_size=4,
                stride=1,
                padding=1,
                padding_mode="reflect",
            )
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))


def test():
    """Smoke test: print the discriminator's output shape for a random batch."""
    batch = torch.randn((5, 3, 256, 256))
    critic = Discriminator(in_channels=3)
    scores = critic(batch)
    print(scores.shape)


# Run the smoke test when this cell is executed as the main module.
if __name__ == "__main__":
    test()

In [None]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        down: When true use a reflect-padded ``nn.Conv2d``; otherwise use
            ``nn.ConvTranspose2d`` (built without a ``padding_mode`` argument).
        use_act: When true finish with ``nn.ReLU``; otherwise ``nn.Identity``.
        **kwargs: Forwarded unchanged to the chosen convolution layer
            (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            convolution = nn.Conv2d(
                in_channels, out_channels, padding_mode="reflect", **kwargs
            )
        else:
            convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        normalisation = nn.InstanceNorm2d(out_channels)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        result = self.conv(x)
        return result


class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers whose output is added back onto the input.

    The first ConvBlock keeps its activation; the second is built with
    ``use_act=False``. The channel count is unchanged.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_options = dict(kernel_size=3, padding=1)
        activated = ConvBlock(channels, channels, **conv_options)
        linear = ConvBlock(channels, channels, use_act=False, **conv_options)
        self.block = nn.Sequential(activated, linear)

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class Generator(nn.Module):
    """Image-to-image generator: stem -> 2 down blocks -> residual stack -> 2 up blocks -> output conv.

    The stem and output layers are 7x7 reflect-padded convolutions. The two
    down blocks double the channel width each time with stride-2 ConvBlocks,
    and the two up blocks mirror them with transposed convolutions. The final
    convolution maps back to ``img_channels`` and is followed by ``tanh``.

    Args:
        img_channels: Channels of the input and output images.
        num_features: Channel width after the stem.
        num_residuals: Number of ``ResidualBlock`` layers in the middle.
    """

    def __init__(self, img_channels, num_features=64, num_residuals=9):
        super().__init__()
        narrow = num_features * 1
        medium = num_features * 2
        wide = num_features * 4
        edge_conv = dict(kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        resample = dict(kernel_size=3, stride=2, padding=1)

        # Attribute names and construction order are kept so state_dict keys
        # (and seeded initialisation) match existing checkpoints.
        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, num_features, **edge_conv),
            nn.InstanceNorm2d(num_features),
            nn.ReLU(inplace=True),
        )

        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, **resample)
                for c_in, c_out in ((num_features, medium), (medium, wide))
            ]
        )

        residual_stack = []
        for _ in range(num_residuals):
            residual_stack.append(ResidualBlock(wide))
        self.res_blocks = nn.Sequential(*residual_stack)

        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(c_in, c_out, down=False, output_padding=1, **resample)
                for c_in, c_out in ((wide, medium), (medium, narrow))
            ]
        )

        self.last = nn.Conv2d(narrow, img_channels, **edge_conv)

    def forward(self, x):
        out = self.initial(x)
        # Down blocks, the residual stack, then up blocks, applied in order.
        for stage in (*self.down_blocks, self.res_blocks, *self.up_blocks):
            out = stage(out)
        return torch.tanh(self.last(out))


def test():
    """Smoke test: print the generator's output shape for a random batch."""
    img_channels = 3
    img_size = 96
    x = torch.randn((2, img_channels, img_size, img_size))
    # Pass the residual count by keyword: the second positional parameter of
    # Generator is num_features, so a bare 9 would set the channel width to 9.
    gen = Generator(img_channels, num_residuals=9)
    print(gen(x).shape)


# Run the smoke test when this cell is executed as the main module.
if __name__ == "__main__":
    test()

In [None]:
import torch

import sys
#from utils import save_checkpoint, load_checkpoint
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
#import config
from tqdm import tqdm
from torchvision.utils import save_image
#from discriminator_model import Discriminator
#from generator_model import Generator


In [None]:
def train_fn(
    disc_H, disc_C, gen_C, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler
):
    """Run one pass over ``loader``, updating both discriminators and both generators.

    For every batch the discriminators are stepped first (real vs. detached
    fake images), then the generators (adversarial + cycle-consistency losses,
    plus identity losses when ``LAMBDA_IDENTITY`` is non-zero). Both updates
    run under ``torch.cuda.amp.autocast`` and go through their gradient scaler.

    Args:
        disc_H: Discriminator for the Healthy domain.
        disc_C: Discriminator for the Common Scab domain.
        gen_C: Generator producing Common Scab images.
        gen_H: Generator producing Healthy images.
        loader: Iterable yielding ``(Common_Scab, Healthy)`` batches.
        opt_disc: Optimizer stepped with the discriminator loss.
        opt_gen: Optimizer stepped with the generator loss.
        l1: Loss callable used for the cycle and identity terms.
        mse: Loss callable used for the adversarial terms.
        d_scaler: Gradient scaler for the discriminator step.
        g_scaler: Gradient scaler for the generator step.

    Side effects:
        Writes two images per batch to hard-coded paths and updates the tqdm
        postfix with running means of the Healthy discriminator outputs.
    """
    H_reals = 0
    H_fakes = 0
    loop = tqdm(loader, leave=True)

    for idx, (Common_Scab, Healthy) in enumerate(loop):
        Common_Scab = Common_Scab.to(DEVICE)
        Healthy = Healthy.to(DEVICE)

        # Train discriminators H and C
        with torch.cuda.amp.autocast():
            fake_Healthy = gen_H(Common_Scab)
            D_H_real = disc_H(Healthy)
            # detach(): the discriminator loss must not backpropagate into the generator
            D_H_fake = disc_H(fake_Healthy.detach())
            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()
            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_Common_Scab = gen_C(Healthy)
            D_C_real = disc_C(Common_Scab)
            D_C_fake = disc_C(fake_Common_Scab.detach())
            D_C_real_loss = mse(D_C_real, torch.ones_like(D_C_real))
            D_C_fake_loss = mse(D_C_fake, torch.zeros_like(D_C_fake))
            D_C_loss = D_C_real_loss + D_C_fake_loss

            # put it together
            D_loss = (D_H_loss + D_C_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generators H and C
        with torch.cuda.amp.autocast():
            # adversarial loss for both generators
            D_H_fake = disc_H(fake_Healthy)
            D_C_fake = disc_C(fake_Common_Scab)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_C = mse(D_C_fake, torch.ones_like(D_C_fake))

            # cycle loss
            cycle_Common_Scab = gen_C(fake_Healthy)
            cycle_Healthy = gen_H(fake_Common_Scab)
            cycle_Common_Scab_loss = l1(Common_Scab, cycle_Common_Scab)
            cycle_Healthy_loss = l1(Healthy, cycle_Healthy)

            # add all together
            G_loss = (
                loss_G_C
                + loss_G_H
                + cycle_Common_Scab_loss * LAMBDA_CYCLE
                + cycle_Healthy_loss * LAMBDA_CYCLE
            )

            # identity loss: a zero-weighted term adds nothing to the loss or
            # its gradients, so the two extra generator passes are only run
            # when the weight is non-zero.
            if LAMBDA_IDENTITY != 0:
                identity_Common_Scab = gen_C(Common_Scab)
                identity_Healthy = gen_H(Healthy)
                identity_Common_Scab_loss = l1(Common_Scab, identity_Common_Scab)
                identity_Healthy_loss = l1(Healthy, identity_Healthy)
                G_loss = (
                    G_loss
                    + identity_Healthy_loss * LAMBDA_IDENTITY
                    + identity_Common_Scab_loss * LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # `* 0.5 + 0.5` maps generator outputs from [-1, 1] back to [0, 1].
        # NOTE(review): fake_Healthy (generated from the Common_Scab batch) is
        # written to common_{idx}.png and fake_Common_Scab to healthy_{idx}.png
        # -- confirm the file names are meant to reflect the source domain.
        save_image(fake_Healthy * 0.5 + 0.5, f"C:\\Users\\Mukaffi\\Desktop\\Data\\common_{idx}.png")
        save_image(fake_Common_Scab * 0.5 + 0.5, f"C:\\Users\\Mukaffi\\Desktop\\Data\\healthy_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))



# save Model

In [None]:
import random, torch, os, numpy as np
import torch.nn as nn
import copy

def save_checkpoint(model, optimizer, filename="D:\\Datasets\\model.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``.

    Args:
        model: Object exposing ``state_dict()``; stored under ``"state_dict"``.
        optimizer: Object exposing ``state_dict()``; stored under ``"optimizer"``.
        filename: Destination path of the checkpoint file.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

# load_checkpoint

In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from ``checkpoint_file``, then force ``lr``.

    Args:
        checkpoint_file: Path to a checkpoint holding ``"state_dict"`` and
            ``"optimizer"`` entries.
        model: Module that receives the stored ``"state_dict"``.
        optimizer: Optimizer that receives the stored ``"optimizer"`` state.
        lr: Learning rate written into every parameter group after loading.
    """
    print("=> Loading checkpoint")
    saved_state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved_state["state_dict"])
    optimizer.load_state_dict(saved_state["optimizer"])

    # Loading the optimizer state also restores the checkpoint's learning
    # rate, so overwrite it with the requested value in every group.
    for group in optimizer.param_groups:
        group.update(lr=lr)

In [None]:
def main():
    """Build the models, optimizers and data loader, then train and checkpoint.

    Creates two discriminators and two generators on ``DEVICE``, one Adam
    optimizer per group, optionally restores all four checkpoints
    (``LOAD_MODEL``), runs the training loop and optionally saves all four
    checkpoints after each epoch (``SAVE_MODEL``).
    """
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_C = Discriminator(in_channels=3).to(DEVICE)
    gen_C = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_C.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(gen_C.parameters()) + list(gen_H.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # (checkpoint file, model, optimizer) for every network; one table drives
    # both loading and saving so the two can never get out of sync.
    checkpoint_targets = (
        (CHECKPOINT_GEN_H, gen_H, opt_gen),
        (CHECKPOINT_GEN_C, gen_C, opt_gen),
        (CHECKPOINT_CRITIC_H, disc_H, opt_disc),
        (CHECKPOINT_CRITIC_C, disc_C, opt_disc),
    )

    if LOAD_MODEL:
        for checkpoint_file, model, optimizer in checkpoint_targets:
            load_checkpoint(checkpoint_file, model, optimizer, LEARNING_RATE)

    dataset = Dset(
        root_Healthy="D:\\Datasets\\Thesis Data Train\\trainB",
        root_Common_Scab="C:\\Users\\Mukaffi\\Desktop\\Potato\\Train\\Common_Scab",
        transform=transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # NOTE(review): the loop runs a single epoch (range(1)) -- confirm that
    # this is intended rather than a configurable epoch count.
    for epoch in range(1):
        print("No of epoch:")
        print(epoch)
        train_fn(
            disc_H,
            disc_C,
            gen_C,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )
        if SAVE_MODEL:
            for checkpoint_file, model, optimizer in checkpoint_targets:
                save_checkpoint(model, optimizer, filename=checkpoint_file)


# Start training when this cell is executed as the main module.
if __name__ == "__main__":
    main()