In [None]:
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install tensorboard
# !pip install albumentations
# !pip install matplotlib
#!pip install torchsummary
#!pip install streamlit
#!pip install tqdm

In [None]:
import torch 
import torch .nn as nn 
import torch .optim as optim 
import torchvision
import torchvision.datasets as datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader , Dataset
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from tqdm import tqdm
from models import Dicriminator , Genrator , init_weights
import os
import numpy as np
import matplotlib.pyplot as plt

import pandas as pd


In [None]:
# Load the paired-image index table. Downstream code reads the columns
# 'input_image', 'output_image' (file paths) and 'label' (class name).
data = pd.read_excel(io='fashion_data_clean.xlsx')

In [None]:
# Shared preprocessing for an (input, target) image pair:
#   * resize both images to 256x256,
#   * map pixels in [0, 255] to [-1, 1] via (x - 0.5*255) / (0.5*255),
#   * convert HWC numpy arrays to CHW torch tensors.
# additional_targets routes "image1" through exactly the same ops as "image".
transform = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image1": "image"},
)

In [None]:

class Mydata(Dataset):
    """Paired image-to-image dataset backed by a pandas DataFrame.

    Each row must provide an ``input_image`` path, an ``output_image`` path and
    a ``label`` string (one of the keys of ``LABEL_MAP``). ``__getitem__``
    returns ``(input_image, output_image, label)``, where both images have been
    passed through ``transform`` together (as the ``image=`` / ``image1=``
    targets) and ``label`` is a 0-d ``torch.long`` tensor.
    """

    # Class name -> integer id; keys must match the strings in the 'label' column.
    LABEL_MAP = {'jeans': 0, 'suit': 1, 'jacket': 2, 't shirt': 3, 'hoodie': 4, 'trouser': 5}

    def __init__(self, data, transform):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_rgb_uint8(image):
        """Coerce an array returned by ``plt.imread`` to a contiguous (H, W, 3) uint8 array.

        ``plt.imread`` returns float arrays in [0, 1] for PNG files but integer
        arrays in [0, 255] for JPEG files. It may also yield grayscale (H, W),
        single-channel (H, W, 1) or RGBA (H, W, 4) arrays. The transform
        normalizes with ``max_pixel_value=255``, so every input is brought to
        0-255 RGB here.
        """
        image = np.asarray(image)
        if np.issubdtype(image.dtype, np.floating):
            # PNG path: rescale [0, 1] floats to the 0-255 range the transform expects.
            image = np.round(np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        elif image.shape[-1] == 4:
            image = image[..., :3]  # drop alpha
        # Slicing off alpha yields a strided view; make it contiguous for the resize step.
        return np.ascontiguousarray(image)

    def __getitem__(self, index):
        row = self.data.iloc[index]

        label_name = row['label']
        try:
            label_id = self.LABEL_MAP[label_name]
        except KeyError:
            raise KeyError(
                f"Unknown label {label_name!r} at row {index}; "
                f"expected one of {sorted(self.LABEL_MAP)}"
            ) from None
        label = torch.tensor(label_id, dtype=torch.long)

        input_image = self._to_rgb_uint8(plt.imread(row['input_image']))
        output_image = self._to_rgb_uint8(plt.imread(row['output_image']))

        augmentation = self.transform(image=input_image, image1=output_image)
        return augmentation["image"], augmentation["image1"], label



In [None]:
# Training configuration: every value a reader might tune lives in this cell.
SEED = 42  # fixed seed so weight init and DataLoader shuffling are reproducible
torch.manual_seed(SEED)  # also seeds all CUDA devices
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8
lr = 2e-4
z_dim = 100
channels = 3
filters = 64
epochs = 1000
# One TensorBoard run directory per image stream (input / generated / target).
writer_input = SummaryWriter("log/input")
writer_fake = SummaryWriter("log/fake")
writer_real = SummaryWriter("log/real")
num_label = 6  # must match the number of classes in Mydata's label mapping
l1_lambda = 100  # weight of the L1 reconstruction term in the generator loss
img_size = 256  # should agree with the Resize size in the transform cell


In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write model and optimizer state to ``filename`` with ``torch.save``.

    The saved object is a dict with two entries: ``"state_dict"`` (the model's
    ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).
    ``load_checkpoint`` reads back the same layout.
    """
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )

In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, map_location="cpu"):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: path to a file holding a dict with ``"state_dict"`` and
            ``"optimizer"`` entries.
        model: module whose parameters/buffers are overwritten in place.
        optimizer: optimizer whose state (including param-group hyperparameters
            such as ``lr``) is overwritten in place.
        map_location: forwarded to ``torch.load``. The default ``"cpu"`` lets a
            checkpoint saved on a GPU be opened on a CPU-only machine;
            ``load_state_dict`` then copies the tensors onto the model's and
            optimizer's own devices.
    """
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

In [None]:
# Generator and discriminator (both are called with the class label in the
# training loop), moved to the selected device.
gen = Genrator(channels, num_label, img_size, filters).to(device)
disc = Dicriminator(channels, num_label, img_size).to(device)

# Weight initialisation from models.init_weights; disc first, then gen.
init_weights(disc)
init_weights(gen)

# Adam with beta1 = 0.5 for both networks.
opt_gen = optim.Adam(params=gen.parameters(), lr=lr, betas=(0.5, 0.999))
opt_disc = optim.Adam(params=disc.parameters(), lr=lr, betas=(0.5, 0.999))

# Adversarial loss on raw logits (sigmoid applied inside) and an L1 pixel loss.
criterion = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

In [None]:
# Wrap the index table in the paired-image Dataset and batch it,
# reshuffling every epoch and loading with two worker processes.
my_data = Mydata(data=data, transform=transform)
data_loder = DataLoader(
    dataset=my_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)


In [None]:
# Main training loop: per batch, one discriminator update followed by one
# generator update (adversarial loss + weighted L1 reconstruction loss).
for epoch in range(epochs):
    for input_image, output_image, label in tqdm(data_loder, leave=True):
        input_image = input_image.to(device)
        output_image = output_image.to(device)
        label = label.to(device)

        # --- Discriminator: real (input, target) pairs -> 1, generated pairs -> 0 ---
        fake = gen(input_image, label)
        disc_real = disc(input_image, output_image, label)
        # detach() so the discriminator loss does not backpropagate into the generator
        disc_fake = disc(input_image, fake.detach(), label)
        disc_real_loss = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_loss = (disc_real_loss + disc_fake_loss) / 2

        opt_disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # --- Generator: fool the just-updated discriminator and match the target (L1) ---
        output = disc(input_image, fake, label)
        fake_loss = criterion(output, torch.ones_like(output))
        l1 = l1_loss(fake, output_image) * l1_lambda
        gen_loss = fake_loss + l1

        # This backward also deposits grads on disc parameters; they are cleared
        # by opt_disc.zero_grad() before the next discriminator update.
        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

    # Every 50 epochs, and always after the final epoch, log a preview from the
    # last batch and write checkpoints.
    if epoch % 50 == 0 or epoch == epochs - 1:
        # eval() so the preview pass does not update normalization running stats
        # (relevant if the generator has BatchNorm-style layers -- see models.py).
        gen.eval()
        with torch.no_grad():
            fake = gen(input_image, label)
            img_grid_real = torchvision.utils.make_grid(output_image[:4], normalize=True)
            img_grid_fake = torchvision.utils.make_grid(fake[:4], normalize=True)
            img_grid_input = torchvision.utils.make_grid(input_image[:4], normalize=True)
        gen.train()

        writer_real.add_image("Real", img_grid_real, global_step=epoch)
        writer_fake.add_image("Fake", img_grid_fake, global_step=epoch)
        writer_input.add_image("Real", img_grid_input, global_step=epoch)
        save_checkpoint(gen, opt_gen, filename="gen_checkpoint.pth.tar")
        save_checkpoint(disc, opt_disc, filename="disc_checkpoint.pth.tar")

    

