In [1]:
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [2]:
class Dataset(torch.utils.data.Dataset):
    """Unpaired image-folder dataset returning normalised RGB image tensors.

    Every image file directly inside ``<root>/<img_dir>/`` is one sample. Each
    sample is converted to RGB, resized to ``image_size x image_size`` and
    normalised per channel with mean 0.5 / std 0.5, i.e. mapped to [-1, 1].

    NOTE(review): this name shadows the ``Dataset`` imported from
    ``torch.utils.data`` at the top of the notebook; the base class is
    therefore referenced by its full path.
    """

    def __init__(self, img_dir, root=None, image_size=None):
        """Index the image files in ``<root>/<img_dir>``.

        Args:
            img_dir: sub-folder name, e.g. "trainA" or "testB".
            root: dataset root directory; defaults to the notebook global
                ``d_path`` (read when the dataset is constructed).
            image_size: side length images are resized to; defaults to the
                notebook global ``img_size`` (read when the dataset is constructed).
        """
        root = d_path if root is None else root
        image_size = img_size if image_size is None else image_size

        # Trailing "" keeps a trailing separator on the directory path.
        img_dir = os.path.join(root, img_dir, "")
        abspath = os.path.abspath(img_dir)

        # Keep only regular files whose extension Pillow has a registered plugin
        # for (stray files / sub-folders would otherwise crash __getitem__), and
        # sort so the sample order does not depend on the filesystem listing order.
        image_exts = Image.registered_extensions()
        self.img_dir = img_dir
        self.img_list = [
            os.path.join(abspath, name)
            for name in sorted(os.listdir(img_dir))
            if os.path.splitext(name)[1].lower() in image_exts
            and os.path.isfile(os.path.join(abspath, name))
        ]

        self.transform = transforms.Compose([
            transforms.Resize([image_size, image_size]),
            transforms.ToTensor(),
            # ToTensor gives [0, 1]; (x - 0.5) / 0.5 maps each channel to [-1, 1].
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        """Number of indexed image files."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return the transformed tensor."""
        path = self.img_list[idx]
        img = Image.open(path).convert('RGB')

        img_tensor = self.transform(img)
        return img_tensor

In [4]:
# Dataset root; the notebook reads the trainA/trainB (and later testA/testB) sub-folders from it.
d_path = r"C:\Users\Vikash kumar singh\Desktop\dataset\selfie2anime"

# Sub-folder names for domain A and domain B training images.
A_dataset, B_dataset = "trainA", "trainB"

# Images are resized to img_size x img_size; batch_size images per training batch.
img_size, batch_size = 128, 32

In [5]:
# One unpaired dataset per domain, each wrapped in its own shuffled loader.
a_dataset, b_dataset = Dataset(A_dataset), Dataset(B_dataset)

data_loader_a, data_loader_b = (
    DataLoader(ds, batch_size, shuffle=True) for ds in (a_dataset, b_dataset)
)

In [6]:
class Discriminator(nn.Module):
    """Fully-convolutional discriminator producing one score per image.

    Five 4x4 convolutions (the first three with stride 2, the last two with
    stride 1) reduce the input to a 1-channel score map. That map is then
    average-pooled over its whole spatial extent, giving an output of shape (N, 1).

    The layer order/indices of ``self.main`` define the state_dict keys, so
    they must stay stable for saved checkpoints to load.
    """

    def __init__(self, conv_dim=32):
        super().__init__()
        layers = [
            # No normalisation on the first layer.
            nn.Conv2d(3, conv_dim, kernel_size=4, padding=1, stride=2),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (input multiplier, output multiplier, stride) for the normalised stages.
        for in_mult, out_mult, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            out_channels = conv_dim * out_mult
            layers += [
                nn.Conv2d(conv_dim * in_mult, out_channels, 4, stride=stride, padding=1),
                nn.InstanceNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Final projection to a single-channel score map.
        layers.append(nn.Conv2d(conv_dim * 8, 1, 4, padding=1))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        score_map = self.main(x)
        # Average over the full spatial extent, then drop the 1x1 spatial dims.
        pooled = F.avg_pool2d(score_map, score_map.size()[2:])
        return pooled.flatten(1)

In [7]:
class ResidualBlock(nn.Module):
    """Residual unit computing ``x + F(x)``.

    F is reflection-pad(1) -> 3x3 conv -> InstanceNorm -> ReLU ->
    reflection-pad(1) -> 3x3 conv -> InstanceNorm. Channel count and spatial
    size are unchanged, which is what makes the skip addition valid.
    """

    def __init__(self, in_channels):
        super().__init__()

        def padded_conv_norm():
            # Reflection padding of 1 keeps H and W unchanged through the 3x3 conv.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_channels, in_channels, 3),
                nn.InstanceNorm2d(in_channels),
            ]

        first_stage = padded_conv_norm()
        second_stage = padded_conv_norm()
        self.main = nn.Sequential(*first_stage, nn.ReLU(inplace=True), *second_stage)

    def forward(self, x):
        residual = self.main(x)
        return x + residual

In [8]:
class Generator(nn.Module):
    """ResNet-style image-to-image generator (3-channel in, 3-channel out).

    Layout of ``self.main`` (its indices define the checkpoint state_dict keys):
      * reflection-pad(3) + 7x7 conv -> conv_dim channels, InstanceNorm, ReLU
      * two 3x3 stride-2 convs: conv_dim -> 2*conv_dim -> 4*conv_dim
      * nine ResidualBlock(4*conv_dim)
      * two 3x3 stride-2 transposed convs back to conv_dim channels
      * reflection-pad(3) + 7x7 conv -> 3 channels, then Tanh (outputs in [-1, 1])
    For H and W divisible by 4 the output has the same spatial size as the input.

    Args:
        conv_dim: channel width of the first convolution.
        n_res_block: NOTE(review): accepted but currently ignored -- exactly
            nine residual blocks are always built. The checkpoints this notebook
            loads go through a strict ``load_state_dict`` into this nine-block
            layout, so honouring the argument (the notebook passes 6) would stop
            them from loading. Keep it fixed until the checkpoints are retrained.
    """

    # Number of residual blocks actually built (fixed for checkpoint compatibility).
    _N_RES_BLOCKS = 9

    def __init__(self, conv_dim=64, n_res_block=9):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(3, conv_dim, 7),
            # Was InstanceNorm2d(64): only matched the conv output when conv_dim == 64.
            nn.InstanceNorm2d(conv_dim),
            nn.ReLU(inplace=True),

            # Downsampling.
            nn.Conv2d(conv_dim, conv_dim*2, 3, stride=2, padding=1),
            nn.InstanceNorm2d(conv_dim*2),
            nn.ReLU(inplace=True),
            nn.Conv2d(conv_dim*2, conv_dim*4, 3, stride=2, padding=1),
            nn.InstanceNorm2d(conv_dim*4),
            nn.ReLU(inplace=True),

            # Residual trunk (indices 10..18 of self.main).
            *[ResidualBlock(conv_dim*4) for _ in range(self._N_RES_BLOCKS)],

            # Upsampling.
            nn.ConvTranspose2d(conv_dim*4, conv_dim*2, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(conv_dim*2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(conv_dim*2, conv_dim, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(conv_dim),
            nn.ReLU(inplace=True),

            nn.ReflectionPad2d(3),
            nn.Conv2d(conv_dim, 3, 7),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

In [9]:
# Base channel widths of the generators and discriminators.
g_conv_dim = 64
d_conv_dim = 64
# Generator always builds 9 residual blocks (it does not honour n_res_block), and the
# checkpoints below are loaded strictly into that layout -- record the real value
# instead of the misleading 6.
n_res_block = 9

In [13]:
def load_model(model, checkpoint_path, map_location="cpu"):
    """Load a saved ``state_dict`` into ``model`` and switch it to eval mode.

    Args:
        model: an already-constructed ``nn.Module`` whose architecture matches the
            checkpoint (loading is strict, so missing/unexpected keys raise).
        checkpoint_path: file written by ``torch.save(model.state_dict(), ...)``.
        map_location: where tensors are deserialised before being copied into
            ``model``. Defaults to "cpu" so checkpoints saved on a GPU also load on
            CPU-only machines; ``load_state_dict`` then copies the values onto the
            device the model's parameters already live on.

    Returns:
        The same ``model`` object, now in eval mode.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model

# Epoch-20 checkpoints, relative to the notebook's working directory.
G_A2B_path = r'G_AtoB_epoch_20.pth'
G_B2A_path = r'G_BtoA_epoch_20.pth'
D_A_path = r'D_A_epoch_20.pth'
D_B_path = r'D_B_epoch_20.pth'

In [14]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build both generators, then both discriminators, directly on `device`.
G_AtoB, G_BtoA = (
    Generator(conv_dim=g_conv_dim, n_res_block=n_res_block).to(device) for _ in range(2)
)
D_A, D_B = (Discriminator(conv_dim=d_conv_dim).to(device) for _ in range(2))

# Restore the saved weights in place (load_model returns the same module, in eval mode).
G_AtoB, G_BtoA, D_A, D_B = (
    load_model(net, ckpt_path)
    for net, ckpt_path in (
        (G_AtoB, G_A2B_path),
        (G_BtoA, G_B2A_path),
        (D_A, D_A_path),
        (D_B, D_B_path),
    )
)

In [15]:
def real_mse_loss(D_out):
    """Least-squares GAN loss towards the "real" label 1: mean of (D_out - 1)^2."""
    deviation = D_out - 1
    return deviation.pow(2).mean()

def fake_mse_loss(D_out):
    """Least-squares GAN loss towards the "fake" label 0: mean of D_out^2."""
    return D_out.pow(2).mean()

def cycle_consistency_loss(real_img, reconstructed_img, lambda_weight):
    """Weighted L1 reconstruction loss: lambda_weight * mean(|real - reconstructed|)."""
    l1 = (real_img - reconstructed_img).abs().mean()
    return lambda_weight * l1

In [12]:
def train_generator(images_a, images_b, opt_g, lambda_cycle=10):
    """Run one optimisation step for both generators.

    The total loss is the sum of:
      * adversarial (least-squares) losses pushing D_A(G_BtoA(b)) and
        D_B(G_AtoB(a)) towards 1, and
      * cycle-consistency L1 losses for a -> b -> a and b -> a -> b,
        each scaled by ``lambda_cycle``.

    Args:
        images_a: batch of real domain-A images (already on ``device``).
        images_b: batch of real domain-B images.
        opt_g: optimizer stepped here; the notebook builds it over the
            parameters of G_AtoB and G_BtoA.
        lambda_cycle: weight of each cycle-consistency term (default 10).

    Returns:
        The total generator loss as a Python float.

    Note: the backward pass also writes gradients into D_A / D_B parameters;
    they are not stepped here and are zeroed in ``train_discriminator``.
    """
    opt_g.zero_grad()

    # B -> A: fool D_A, then map back to B and compare with the input.
    fake_images_a = G_BtoA(images_b)
    g_BtoA_loss = real_mse_loss(D_A(fake_images_a))
    recon_b = G_AtoB(fake_images_a)
    recon_b_loss = cycle_consistency_loss(images_b, recon_b, lambda_weight=lambda_cycle)

    # A -> B: fool D_B, then map back to A and compare with the input.
    fake_images_b = G_AtoB(images_a)
    g_AtoB_loss = real_mse_loss(D_B(fake_images_b))
    recon_a = G_BtoA(fake_images_b)
    recon_a_loss = cycle_consistency_loss(images_a, recon_a, lambda_weight=lambda_cycle)

    g_total_loss = g_BtoA_loss + g_AtoB_loss + recon_b_loss + recon_a_loss
    g_total_loss.backward()
    opt_g.step()

    return g_total_loss.item()

In [13]:
def train_discriminator(images_a, images_b, opt_d_a, opt_b):
    """Run one least-squares GAN update for each discriminator.

    D_A learns to score real domain-A images as 1 and G_BtoA(images_b) as 0;
    D_B does the same for domain B with G_AtoB(images_a).

    Args:
        images_a: batch of real domain-A images (already on ``device``).
        images_b: batch of real domain-B images.
        opt_d_a: optimizer over D_A's parameters.
        opt_b: optimizer over D_B's parameters. The parameter name is kept for
            keyword compatibility; it is now actually used. Previously the
            global ``opt_d_b`` was stepped instead.

    Returns:
        ``(d_a_loss, d_b_loss)`` as Python floats.
    """
    # ---- D_A: real A vs. translated B -> A ----
    opt_d_a.zero_grad()

    d_real_loss_a = real_mse_loss(D_A(images_a))

    # The generators are not updated here, so build the fakes without a graph:
    # D's backward pass then stops at the fake images instead of also
    # backpropagating through (and filling .grad of) the generator.
    with torch.no_grad():
        fake_images_a = G_BtoA(images_b)
    d_fake_loss_a = fake_mse_loss(D_A(fake_images_a))

    d_a_loss = d_real_loss_a + d_fake_loss_a
    d_a_loss.backward()
    opt_d_a.step()

    # ---- D_B: real B vs. translated A -> B (uses the optimizer passed in) ----
    opt_b.zero_grad()

    d_real_loss_b = real_mse_loss(D_B(images_b))

    with torch.no_grad():
        fake_images_b = G_AtoB(images_a)
    d_fake_loss_b = fake_mse_loss(D_B(fake_images_b))

    d_b_loss = d_real_loss_b + d_fake_loss_b
    d_b_loss.backward()
    opt_b.step()

    return d_a_loss.item(), d_b_loss.item()


In [28]:
import os
import torch
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt

def denorm(tensor):
    """Undo the [-1, 1] scaling: map x to (x + 1) / 2, i.e. [-1, 1] -> [0, 1]."""
    shifted = tensor + 1
    return shifted / 2

def save_samples_cyclegan(index, real_A, real_B, fake_A, fake_B, show=True,
                          sample_dir = r"C:\Users\Vikash kumar singh\Desktop\New folder\cynaptics_induc\induction-task\selfie2anime_output"):
    """Save real/fake image grids for one step and optionally plot them.

    Writes fake_A, fake_B, real_A and real_B (in that order) to ``sample_dir`` as
    ``<tag>-NNNN.png``, where NNNN is ``index`` zero-padded to 4 digits, with 8
    images per grid row. Each batch is passed through ``denorm`` first, so inputs
    are assumed to be in [-1, 1].

    Args:
        index: integer used in the file names (e.g. the epoch or batch number).
        real_A, real_B, fake_A, fake_B: image batches to save.
        show: if True, also display a 2x2 matplotlib figure
            (Real A, Fake B / Real B, Fake A).
        sample_dir: output folder; created if missing.
    """
    os.makedirs(sample_dir, exist_ok=True)

    outputs = (
        ('fake_A-{0:0=4d}.png', fake_A),
        ('fake_B-{0:0=4d}.png', fake_B),
        ('real_A-{0:0=4d}.png', real_A),
        ('real_B-{0:0=4d}.png', real_B),
    )
    written = []
    for template, batch in outputs:
        fname = template.format(index)
        save_image(denorm(batch), os.path.join(sample_dir, fname), nrow=8)
        written.append(fname)

    print('Saving', *written)

    if not show:
        return

    fig, axes = plt.subplots(2, 2, figsize=(12, 12))
    panels = (("Real A", real_A), ("Fake B", fake_B), ("Real B", real_B), ("Fake A", fake_A))
    for ax, (title, batch) in zip(axes.ravel(), panels):
        ax.set_title(title)
        # make_grid gives (C, H, W); imshow wants (H, W, C).
        ax.imshow(make_grid(denorm(batch).cpu(), nrow=8).permute(1, 2, 0))
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [15]:
def _save_models(suffix):
    """Save the state_dict of all four CycleGAN networks to ``<name><suffix>``."""
    for name, model in (("G_AtoB", G_AtoB), ("G_BtoA", G_BtoA), ("D_A", D_A), ("D_B", D_B)):
        torch.save(model.state_dict(), f"{name}{suffix}")


def train(epochs, dataloader_a, dataloader_b, opt_g, opt_d_a, opt_d_b):
    """Train the CycleGAN generators and discriminators.

    Each step pairs one batch from ``dataloader_a`` with one from
    ``dataloader_b`` (``zip``, so an epoch stops at the shorter loader). It runs
    a generator update followed by a discriminator update.

    Side effects:
      * every 5 epochs: saves sample grids (via ``save_samples_cyclegan``) from the
        epoch's last batch, and checkpoints ``<net>_epoch_<N>.pth``;
      * whenever the epoch-mean generator loss reaches a new minimum: saves
        ``<net>_new``;
      * at the end: saves ``<net>_new_last``.
    All files are written relative to the working directory.

    Returns:
        A list with one ``(loss_g, loss_d_a, loss_d_b)`` tuple per epoch, each
        value being the mean over that epoch's batches.

    Raises:
        ValueError: if an epoch yields no batches (an empty loader).
    """
    losses = []
    # np.Inf was removed in NumPy 2.0; float("inf") works with every version.
    loss_g_min = float("inf")

    # load_model() leaves the networks in eval mode; switch them back for training.
    for model in (G_AtoB, G_BtoA, D_A, D_B):
        model.train()

    start_time = time.time()

    for epoch in range(epochs):
        sum_g = sum_d_a = sum_d_b = 0.0
        n_batches = 0
        for images_a, images_b in zip(dataloader_a, dataloader_b):
            images_a, images_b = images_a.to(device), images_b.to(device)
            batch_g = train_generator(images_a, images_b, opt_g)
            batch_d_a, batch_d_b = train_discriminator(images_a, images_b, opt_d_a, opt_d_b)
            sum_g += batch_g
            sum_d_a += batch_d_a
            sum_d_b += batch_d_b
            n_batches += 1

        if n_batches == 0:
            raise ValueError("dataloader_a / dataloader_b yielded no batches")

        # Epoch means: a single batch's GAN loss is too noisy to log or to pick
        # the "best" checkpoint from.
        loss_g = sum_g / n_batches
        loss_d_a = sum_d_a / n_batches
        loss_d_b = sum_d_b / n_batches

        losses.append((loss_g, loss_d_a, loss_d_b))
        total_duration = time.time() - start_time  # elapsed since training started
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d_a: {:.4f},loss_d_b: {:.4f}, duration: {:.4f}".format(
            epoch + 1, epochs, loss_g, loss_d_a, loss_d_b, total_duration))

        if (epoch + 1) % 5 == 0:
            with torch.no_grad():
                # Samples are taken from the last batch of the epoch.
                fake_B = G_AtoB(images_a)
                fake_A = G_BtoA(images_b)

                save_samples_cyclegan(epoch + 1, images_a, images_b, fake_A, fake_B, show=True)
                print(f"Saved generated images for epoch {epoch + 1}")
            _save_models(f"_epoch_{epoch + 1}.pth")

        if loss_g < loss_g_min:
            loss_g_min = loss_g
            _save_models("_new")
            print("Models Saved")

    _save_models("_new_last")
    return losses

In [16]:
lr = 2e-5  # same value as 0.00002
epochs = 20

# One Adam optimizer for both generators together; one per discriminator.
g_params = [*G_AtoB.parameters(), *G_BtoA.parameters()]
opt_g, opt_d_a, opt_d_b = (
    optim.Adam(params, lr, betas=(0.5, 0.999))
    for params in (g_params, D_A.parameters(), D_B.parameters())
)

In [None]:
history = train(
    epochs,
    dataloader_a=data_loader_a,
    dataloader_b=data_loader_b,
    opt_g=opt_g,
    opt_d_a=opt_d_a,
    opt_d_b=opt_d_b,
)

In [18]:
# Persist the current weights of all four networks under the *_new_last names.
for ckpt_name, net in (
    ("G_AtoB_new_last", G_AtoB),
    ("G_BtoA_new_last", G_BtoA),
    ("D_A_new_last", D_A),
    ("D_B_new_last", D_B),
):
    torch.save(net.state_dict(), ckpt_name)

In [21]:
# Dataset reads d_path and img_size as globals, so (re)state them here.
d_path = r"C:\Users\Vikash kumar singh\Desktop\dataset\selfie2anime"
img_size = 128
# NOTE(review): this rebinds the global batch_size (32 for training) to 4 for every later cell.
batch_size = 4

A_test_dataset, B_test_dataset = "testA", "testB"

a__test_dataset = Dataset(A_test_dataset)
b_test_dataset = Dataset(B_test_dataset)

# shuffle=False: batches follow the Dataset's file order.
data_loader_test_a, data_loader_test_b = (
    DataLoader(ds, batch_size, shuffle=False) for ds in (a__test_dataset, b_test_dataset)
)

In [None]:
# Translate the whole test split batch by batch and save/plot each result.
# Test samples get their own folder (relative to the working directory, like the
# checkpoints): the default sample_dir is also where train() writes its per-epoch
# samples under indices 5, 10, ..., which test batches 5, 10, ... would overwrite.
test_sample_dir = "selfie2anime_test_output"

# Put the generators in inference mode before evaluating.
G_AtoB.eval()
G_BtoA.eval()

for i, (real_A, real_B) in enumerate(zip(data_loader_test_a, data_loader_test_b)):
    real_A = real_A.to(device)
    real_B = real_B.to(device)

    with torch.no_grad():
        fake_B = G_AtoB(real_A)
        fake_A = G_BtoA(real_B)

    save_samples_cyclegan(i, real_A, real_B, fake_A, fake_B, show=True, sample_dir=test_sample_dir)

In [None]:
from torchvision import transforms
from PIL import Image

# Same preprocessing as the Dataset transform with img_size = 128:
# resize, convert to a [0, 1] tensor, then scale each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# One domain-A image (selfie) and one domain-B image, both forced to RGB.
real_A_image, real_B_image = (
    Image.open(img_path).convert("RGB")
    for img_path in (
        r"C:\Users\Vikash kumar singh\Pictures\Screenshots\Screenshot 2025-01-16 192822.png",
        r"C:\Users\Vikash kumar singh\Pictures\4365816.jpg",
    )
)

# Add a batch dimension of 1 and move to the model device.
real_A, real_B = (
    transform(pil_img).unsqueeze(0).to(device) for pil_img in (real_A_image, real_B_image)
)

with torch.no_grad():
    fake_B = G_AtoB(real_A)
    fake_A = G_BtoA(real_B)

save_samples_cyclegan(0, real_A, real_B, fake_A, fake_B, show=True,
                      sample_dir=r"C:\Users\Vikash kumar singh\Pictures")