In [2]:
import torch
import torchvision.transforms as transforms
from torchvision.utils import make_grid

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F

import random



In [3]:
# data path
dat_dir = ''


# data config
img_h = 256
img_b = 256
channels = 3


# configs
epochs = 0
n_epochs = 5
batch_size = 1
lr = 0.0002
b1 = 0.5
b2 = 0.999
decay_epoch = 3


In [4]:
class ResidualBLock(nn.Module):

    def __init__(self , in_f):
        super(ResidualBLock , self).__init__()

        self.block = nn.Sequential(
            nn.ReflectionPad2d(1) , 
            nn.Conv2d(in_f , in_f , 3) , 
            nn.InstanceNorm2d(in_f) , 
            # inplace  = True for memory efficiency
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(1) , 
            nn.Conv2d(in_f , in_f , 3), 
            nn.InstanceNorm2d(in_f)
        )


    def forward(self , x):
        return x + self.block(x)
    


class GeneratorResNet(nn.Module):

    def __init__(self , input_shape , num_res_block):

        super(GeneratorResNet , self).__init__()


        channels = input_shape[0]

        out_f = 64

        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels , out_f , 7),
            nn.InstanceNorm2d(out_f),   #helps to learn features which are invariate to brightness and contrast
            nn.ReLU(inplace=True)
            
        ]

        in_f = out_f


        # downsampling

        for _ in range(2):

            out_f *= 2
            model += [
                nn.Conv2d(in_f , out_f , 3,stride=2 , padding=1), 
                nn.InstanceNorm2d(out_f),
                nn.ReLU(inplace=True)
            ]

            in_f = out_f


        # Residual BLock
        for _ in range(num_res_block):
            model += [ResidualBLock(out_f)]


        # upsampling

        for _ in range(2):

            out_f //= 2
            model += [
                nn.Upsample(scale_factor=2), # --> this means width * 2 , height * 2
                nn.Conv2d(in_f, out_f , 3 , stride=1 , padding= 1),
                nn.ReLU(inplace=True)

            ]

            in_f = out_f
        # output lapyer


        model += [
            nn.ReflectionPad2d(channels) , 
            nn.Conv2d(out_f  , channels , 7) ,
            nn.Tanh()
            ]
        
        # unpacking

        self.model = nn.Sequential(*model)

    

    def forward(self , x):
        return self.model(x)
    



In [5]:
class Discriminator(nn.Module):

    def __init__(self , inp_size):

        super(Discriminator , self).__init__()


        channels , height , width = inp_size


        self.output_shape = (1,  height//2**4 , width // 2**4)


        def discriminator_bloc(in_filters , out_filters , normalize = True):
            layers = [
                nn.Conv2d(in_filters , out_filters , 4 , stride=2 , padding=1)]
            
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))

            layers.append(nn.LeakyReLU(0.2 , inplace=True))

            return layers
        
        self.model = nn.Sequential(
            *discriminator_bloc(channels , 64 , normalize=False) ,
            *discriminator_bloc(64 , 128) ,
            *discriminator_bloc(128 , 256) , 
            *discriminator_bloc(256 , 512),
            nn.ZeroPad2d((1,0,1,0)),
            nn.Conv2d(512 , 1 ,4 ,padding=1)
        )


    def forward(self  , img):
         return self.model(img)

In [6]:
# LOSS FUNCTION

criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()


In [7]:
input_shape = (channels , img_h , img_b)

n_residual_blocks = 9

G_AB = GeneratorResNet(input_shape=input_shape , num_res_block=n_residual_blocks)
G_BA = GeneratorResNet(input_shape=input_shape , num_res_block=n_residual_blocks)

D_A = Discriminator(inp_size=input_shape)
D_B = Discriminator(inp_size=input_shape)


In [8]:
cuda = torch.cuda.is_available()

if cuda:
    G_AB = G_AB.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    D_B = D_B.cuda()
    
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

In [9]:
def weights_init_normal(m):

    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal(m.weight.data ,0.0 , 0.02)

        if hasattr(m , 'bias') and m.bias  is not None:
            torch.nn.init.constant(m.bias.data , 0.0)

        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal(m.weight.data , 1.0 , 0.02)
            torch.nn.init.constant(m.bias.data , 0.0)




In [10]:
G_AB.apply(weights_init_normal)
G_BA.apply(weights_init_normal)
D_A.apply(weights_init_normal)
D_B.apply(weights_init_normal)


  torch.nn.init.normal(m.weight.data ,0.0 , 0.02)
  torch.nn.init.constant(m.bias.data , 0.0)


Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): ZeroPad2d((1, 0, 1, 0))
    (12): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [11]:
def temp_weights_init_normal(m):
    classname =  m.__class__.__name__
    print(classname)

In [12]:
G_AB.apply(temp_weights_init_normal);

ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
Conv2d
InstanceNorm2d
ReLU
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPad2d
Conv2d
InstanceNorm2d
Sequential
ResidualBLock
ReflectionPad2d
Conv2d
InstanceNorm2d
ReLU
ReflectionPa

In [13]:
import itertools

optim_G = torch.optim.Adam(itertools.chain(G_AB.parameters() , G_BA.parameters()) , lr=lr , betas=(b1,b2))

optim_D_A = torch.optim.Adam(D_A.parameters() , lr=lr , betas=(b1,b2))

optim_D_B = torch.optim.Adam(D_B.parameters() , lr=lr , betas=(b1,b2))



In [14]:
class LambdaLR:

    def __init__(self , n_epochs  , offset , decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0 , "Decay myst start before the treaining session ends"

        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch


    def step(self , epoch):
        return 1.0 - max(0 , epoch+ self.offset - self.decay_start_epoch)/(self.n_epochs - self.decay_start_epoch)
    


In [15]:
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(
    optim_G,
    lr_lambda=LambdaLR(n_epochs, epochs, decay_epoch).step
)

lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(
    optim_D_A,
    lr_lambda=LambdaLR(n_epochs, epochs, decay_epoch).step
)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(
    optim_D_B,
    lr_lambda=LambdaLR(n_epochs, epochs, decay_epoch).step
)

In [16]:
from PIL import Image
import torchvision.transforms as transforms

transforms_ = [
    transforms.Resize(int(img_h*1.12), Image.BICUBIC),
    transforms.RandomCrop((img_h, img_b)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]

In [17]:
def to_rgb(image):
    rgb_image = Image.new("RGB", image.size)
    rgb_image.paste(image)
    return rgb_image

In [19]:
from torch.utils.data import Dataset

class ImageDataset(Dataset):

    def __init__(self , root ,transforms= None , unaligned = False):
        self.traisnform = transforms.Compose(transforms_)

        self.unaligned = unaligned
        

SyntaxError: incomplete input (525429090.py, line 5)

In [None]:
import os
from PIL import Image
from torch.utils.data import Dataset

data_dir = "/home/aman/code/CV/throat_infection/data"
classes = sorted(os.listdir(data_dir))

class_to_label = {class_name: label for label ,class_name in enumerate(classes)}



class CustomImageDataset(Dataset):

    def __init__(self , data_dir , transforms=None):
        self.data_dir = data_dir
        self.transform = transforms
        self.image_paths = []
        self.labels = []

        for class_name , label in class_to_label.items():
            class_dir = os.path.join(data_dir , class_name)

            for image_path in os.listdir(class_dir):
                self.image_paths.append(os.path.join(class_dir , image_path))
                self.labels.append(label)
        
        random.shuffle(self.image_paths)

    def __len__(self):
        return len(self.image_paths)
    
    def __getitem__(self ,idx):
        random.shuffle(self.image_paths)
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        return image
    

