In [1]:
import torch
import torchvision 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [2]:
import os 
import cv2
import torch
from torch.utils.data import Dataset,DataLoader
import numpy as np
import glob

In [3]:
from torchvision import datasets, transforms
from torch.autograd import Variable

In [4]:
class CustomDataset(Dataset):
    """Dataset of ordered sample pairs taken from the same class folder.

    Every entry directly under ``imgs_path`` is treated as one class folder.
    Inside a folder, a sample consists of three files that are matched by
    sorted filename order: an image (``*.jpg``), an array whose name ends in
    ``).npy`` and an array whose name ends in ``pose.npy``. Each ordered pair
    of two different samples of a class becomes one dataset item.

    Args:
        imgs_path: Root directory prefix (including the trailing separator).
        img_dim: Target size handed to ``cv2.resize`` as ``dsize``, which
            OpenCV reads as ``(width, height)``. The default ``(120, 200)``
            produces 200-row by 120-column images, the frame size the layer
            sizes and ``reshape`` calls in the rest of this notebook assume.

    Raises:
        ValueError: If a class folder does not hold the same number of
            image, ``).npy`` and ``pose.npy`` files.
    """

    def __init__(self, imgs_path='gan_data/', img_dim=(120, 200)):
        self.imgs_path = imgs_path
        self.data = []
        for class_path in sorted(glob.glob(self.imgs_path + "*")):
            # glob returns matches in arbitrary order, so sort each list to
            # keep the image and its two arrays aligned when zipped.
            img_paths = sorted(glob.glob(class_path + "/*.jpg"))
            h_paths = sorted(glob.glob(class_path + "/*).npy"))
            pose_paths = sorted(glob.glob(class_path + "/*pose.npy"))
            if not (len(img_paths) == len(h_paths) == len(pose_paths)):
                raise ValueError(
                    f"mismatched file counts in {class_path}: "
                    f"{len(img_paths)} images, {len(h_paths)} ').npy' files, "
                    f"{len(pose_paths)} 'pose.npy' files"
                )
            # Globbing once per folder (instead of once per outer sample)
            # yields the same pairs without rescanning the directory.
            samples = list(zip(img_paths, h_paths, pose_paths))
            for first in samples:
                for second in samples:
                    if first[0] != second[0]:
                        self.data.append([*first, *second])
        self.img_dim = img_dim

    def __len__(self):
        """Return the number of ordered sample pairs."""
        return len(self.data)

    def _load_sample(self, img_path, z_path, pose_path):
        """Load one (image, ``).npy`` array, pose array) triple as tensors."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {img_path}")
        img = cv2.resize(img, self.img_dim)
        img_tensor = torch.from_numpy(img)
        # Only this array is cast to float32; the pose array keeps its dtype.
        z_tensor = torch.from_numpy(np.load(z_path).astype(np.float32))
        pose_tensor = torch.from_numpy(np.load(pose_path))
        return img_tensor, z_tensor, pose_tensor

    def __getitem__(self, idx):
        """Return the six tensors of pair ``idx``.

        The tuple is ``(img, z, pose, img1, z1, pose1)``: the first sample's
        three tensors followed by the second sample's.
        """
        img_path, z_path, img_pose_path, img1_path, z1_path, img1_pose_path = self.data[idx]
        return (
            *self._load_sample(img_path, z_path, img_pose_path),
            *self._load_sample(img1_path, z1_path, img1_pose_path),
        )

In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [6]:
class Generator(nn.Module):
    """Convolutional encoder/decoder with a 16-unit fully connected bottleneck.

    Four unpadded stride-1 convolutions (kernels 3, 5, 10, 12) encode the
    input, two linear layers squeeze it through 16 features and back, and
    four transposed convolutions with mirrored kernels restore the input
    resolution.

    Args:
        input_hw: ``(height, width)`` of the frames passed to ``forward``.
            The default ``(200, 120)`` gives a 64 x 174 x 94 bottleneck map,
            i.e. 1046784 features for the linear layers.

    Raises:
        ValueError: If ``input_hw`` is too small to survive the encoder.
    """

    def __init__(self, input_hw=(200, 120)):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.conv3 = nn.Conv2d(32, 64, 10)
        self.conv4 = nn.Conv2d(64, 64, 12)

        # Each unpadded stride-1 conv removes (kernel - 1) rows and columns.
        shrink = sum(
            conv.kernel_size[0] - 1
            for conv in (self.conv1, self.conv2, self.conv3, self.conv4)
        )
        height, width = input_hw[0] - shrink, input_hw[1] - shrink
        if height <= 0 or width <= 0:
            raise ValueError(
                f"input_hw={input_hw} must exceed {shrink} in both dimensions"
            )
        self.bottleneck_hw = (height, width)
        flat_features = 64 * height * width

        self.fc1 = nn.Linear(flat_features, 16)
        self.fcT3 = nn.Linear(16, flat_features)

        self.convT1 = nn.ConvTranspose2d(64, 64, 12)
        self.convT2 = nn.ConvTranspose2d(64, 32, 10)
        self.convT3 = nn.ConvTranspose2d(32, 16, 5)
        self.convT4 = nn.ConvTranspose2d(16, 3, 3)

    def forward(self, x):
        """Map a channels-last ``(N, H, W, 4)`` batch to ``(N, H, W, 3)``."""
        x = x.permute(0, 3, 1, 2)  # channels-last -> channels-first for conv layers
        batch = x.shape[0]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fcT3(x))
        # Use the actual batch size so a final partial batch also works.
        x = x.reshape(batch, 64, *self.bottleneck_hw)
        x = F.relu(self.convT1(x))
        x = F.relu(self.convT2(x))
        x = F.relu(self.convT3(x))
        x = F.relu(self.convT4(x))
        return x.permute(0, 2, 3, 1)  # back to channels-last

In [None]:
class Generator(nn.Module):
    """Pooling encoder/decoder variant of the generator.

    The encoder alternates 3x3 convolutions with 2x2 max pooling and keeps
    the pooling indices; the decoder mirrors it with transposed convolutions
    and ``MaxUnpool2d`` driven by those indices. ``fc1`` is sized for a
    64 x 21 x 11 feature map, which is what a 200 x 120 input produces.

    NOTE(review): this notebook defines ``Generator`` in two cells; whichever
    cell ran last is the class that ``Generator()`` instantiates.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.conv3 = nn.Conv2d(32, 48, 3)
        self.conv4 = nn.Conv2d(48, 64, 3)
        self.pool = nn.MaxPool2d(2, return_indices=True)

        self.fc1 = nn.Linear(64 * 21 * 11, 240)
        self.fc2 = nn.Linear(240, 64)
        self.fc3 = nn.Linear(64, 16)

        self.fc_ = nn.Linear(16, 16)

        self.fcT1 = nn.Linear(16, 64)
        self.fcT2 = nn.Linear(64, 240)
        self.fcT3 = nn.Linear(240, 64 * 21 * 11)

        self.convT1 = nn.ConvTranspose2d(64, 48, 3)
        # Kernel 3 (not 4) so the output matches the shape of the pooling
        # indices it is unpooled with; MaxUnpool2d requires equal shapes.
        self.convT2 = nn.ConvTranspose2d(48, 32, 3)
        self.convT3 = nn.ConvTranspose2d(32, 16, 3)
        # Three output channels, matching the discriminator's 3-channel input.
        self.convT4 = nn.ConvTranspose2d(16, 3, 3)
        self.unpool = nn.MaxUnpool2d(2)

    def forward(self, x):
        """Map a channels-last ``(N, 200, 120, 4)`` batch to ``(N, 200, 120, 3)``."""
        x = x.permute(0, 3, 1, 2)
        batch_ = x.shape[0]

        # Encoder: remember each pre-pool shape so unpooling can restore odd
        # sizes exactly (2x2 pooling floors, e.g. 97 -> 48).
        x = F.relu(self.conv1(x))
        size_1 = x.shape
        x, idx_1 = self.pool(x)
        x = F.relu(self.conv2(x))
        size_2 = x.shape
        x, idx_2 = self.pool(x)
        x = F.relu(self.conv3(x))
        size_3 = x.shape
        x, idx_3 = self.pool(x)
        x = F.relu(self.conv4(x))

        # Fully connected bottleneck.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc_(x)
        x = F.relu(self.fcT1(x))
        x = F.relu(self.fcT2(x))
        x = F.relu(self.fcT3(x))
        x = x.reshape(batch_, 64, 21, 11)

        # Decoder: mirror of the encoder, deepest pooling stage first.
        x = self.unpool(F.relu(self.convT1(x)), idx_3, output_size=size_3)
        x = self.unpool(F.relu(self.convT2(x)), idx_2, output_size=size_2)
        x = self.unpool(F.relu(self.convT3(x)), idx_1, output_size=size_1)
        x = F.relu(self.convT4(x))

        return x.permute(0, 2, 3, 1)

In [None]:
# Instantiate a generator for a quick forward-pass check; it is not moved to `device` here.
generator = Generator()


In [None]:
# NOTE(review): within the cells shown here, `noise_vector` is only assigned inside the
# training loop further down, so this call fails on a fresh kernel — TODO confirm this is
# leftover scratch work and remove or move it.
generator(noise_vector)

In [None]:
class Discriminator(nn.Module):
    """Convolutional critic that scores channels-last image batches.

    Three bias-free convolutions, each followed by 2x2 max pooling and ReLU,
    reduce the image to a single-channel map that is flattened into the 171
    features ``fc1`` expects (19 x 9, which is what a 200 x 120 input yields).
    ``forward`` returns raw logits of shape ``(N, 1)``; the ``sigmoid`` module
    is created but not applied in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, bias=False)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, bias=False)
        self.conv3 = nn.Conv2d(128, 1, kernel_size=10, bias=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(171, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one logit per image for a channels-last ``(N, H, W, 3)`` batch."""
        features = x.permute(0, 3, 1, 2)
        # conv -> pool -> ReLU, three times over.
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(self.pool(conv(features)))
        hidden = features.flatten(1)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))
        logits = self.fc3(hidden)
        return logits.flatten(1)

In [None]:
# Build both networks and move their parameters to the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

In [None]:
# Adversarial criterion: binary cross-entropy that takes raw logits (it applies the
# sigmoid itself), which is why the discriminator's forward returns unsquashed scores.
adv_loss = nn.BCEWithLogitsLoss()

In [None]:
def nn_loss(predicted, ground_truth, nh=3, nw=3):
    """Nearest-neighbour L1 loss over an ``nh`` x ``nw`` window.

    For every position of ``predicted`` the L1 distance (summed over dim 1)
    to each ground-truth value inside the surrounding window is computed, and
    the smallest of those distances is kept. The result is the mean of these
    minima, so small spatial misalignments are not penalised.

    Args:
        predicted: Tensor laid out as ``(N, C, H, W)``; dim 1 is summed over
            and the last two dims are the ones the window slides across.
        ground_truth: Tensor with the same shape as ``predicted``.
        nh: Window height (extent along dim 2). Must be odd.
        nw: Window width (extent along dim 3). Must be odd.

    Returns:
        Scalar tensor holding the mean nearest-neighbour distance.

    Raises:
        ValueError: If ``nh`` or ``nw`` is even (no centred window exists).
    """
    if nh % 2 == 0 or nw % 2 == 0:
        raise ValueError(f"nh and nw must be odd, got nh={nh}, nw={nw}")

    v_pad = nh // 2
    h_pad = nw // 2
    # ConstantPad2d expects (left, right, top, bottom): the horizontal pad
    # applies to the last dim, the vertical pad to the one before it.
    # -10000 makes padded cells so distant that they never win the minimum.
    val_pad = nn.ConstantPad2d((h_pad, h_pad, v_pad, v_pad), -10000)(ground_truth)

    height, width = ground_truth.shape[-2], ground_truth.shape[-1]
    # One shifted copy of the ground truth per offset inside the window.
    reference_tensors = [
        val_pad[:, :, i_begin:i_begin + height, j_begin:j_begin + width].unsqueeze(-1)
        for i_begin in range(nh)
        for j_begin in range(nw)
    ]
    reference = torch.cat(reference_tensors, dim=-1)

    abs_diff = torch.abs(reference - predicted.unsqueeze(-1))
    norms = torch.sum(abs_diff, dim=1)
    loss, _ = torch.min(norms, dim=-1)
    return torch.mean(loss)

In [None]:
def gen_loss(fake_img,input_):
    """Adversarial loss for the generator step.

    Passes both arguments straight to the notebook-level ``adv_loss``
    criterion: ``fake_img`` as the prediction and ``input_`` as the target.
    """
    loss = adv_loss(fake_img, input_)
    return loss

In [None]:
def disc_loss(fake_img,target_img):
    """Adversarial loss for the discriminator step.

    Adds a constant 0.0001 to ``fake_img`` and scores the result against
    ``target_img`` with the notebook-level ``adv_loss`` criterion.
    """
    # NOTE(review): the offset is applied to whatever is passed in; if these
    # are raw logits it only shifts them slightly — confirm it is still wanted.
    shifted = fake_img + 0.0001
    return adv_loss(shifted, target_img)

In [None]:
# Both networks are optimised with Adam at the same learning rate.
learning_rate = 0.0001
# Separate optimisers so each training step updates only its own network's parameters.
G_optimizer = optim.Adam(generator.parameters(), lr = learning_rate)
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate)


In [6]:
# Build the pair dataset and serve it in batches of 8.
# NOTE(review): with DataLoader defaults the batches are not shuffled and the last batch
# may hold fewer than 8 items — confirm the code consuming `data_loader` tolerates that.
dataset = CustomDataset()
data_loader = DataLoader(dataset,batch_size = 8)

In [None]:
# Number of full passes over `data_loader` made by the training loop.
num_epochs = 100

In [None]:
# Adversarial training: per batch, one discriminator update followed by one
# generator update. Losses are printed every 100 batches.
for epochs in range(1,num_epochs+1):
    for index, (real_img,real_img_noise,real_img_pose,target_img,target_img_noise,target_img_pose) in enumerate(data_loader):
        real_img = real_img.float().to(device)
        target_img = target_img.float().to(device)
        target_img_noise = target_img_noise.float().to(device)

        # Take sizes from the batch itself so a final partial batch (or a
        # different image size) does not break the reshape below.
        batch_size, height, width = real_img.shape[:3]

        # Label columns for the adversarial criterion: 1 = real, 0 = generated.
        real_target = torch.ones(batch_size, 1, device=device)
        fake_target = torch.zeros(batch_size, 1, device=device)

        # Discriminator training ----
        D_optimizer.zero_grad()

        output = discriminator(real_img)
        D_real_loss = disc_loss(output,real_target)
        D_real_loss.backward()

        # Generator input: the source image with the target's noise array
        # appended as a fourth channel.
        noise_plane = target_img_noise.reshape(batch_size, height, width, 1)
        noise_vector = torch.cat((real_img, noise_plane), dim=-1)
        generated_img = generator(noise_vector)

        # detach() keeps this backward pass from reaching the generator.
        output = discriminator(generated_img.detach())
        D_fake_loss = disc_loss(output,fake_target)
        D_fake_loss.backward()
        D_optimizer.step()

        if index%100 == 0:
            print(f'disc_loss={D_fake_loss.item() + D_real_loss.item()}')

        # Generator training ----
        G_optimizer.zero_grad()
        gen_output = discriminator(generated_img)
        # nn_loss sums over dim 1 and slides its window over the last two
        # dims, so hand it channels-first (N, C, H, W) views of both images.
        G_loss = gen_loss(gen_output,real_target)*10 + nn_loss(
            generated_img.permute(0, 3, 1, 2),
            target_img.permute(0, 3, 1, 2),
        )

        if index%100 == 0:
            print(f'gen_loss={G_loss.item()}')
        G_loss.backward()
        G_optimizer.step()
    print(f'{epochs} done')

In [None]:
# Show the shape of the second image (index 1) in the most recently generated batch.
generated_img[1].shape

In [None]:
# NOTE(review): neither `disc` nor `imgs` is defined in the cells shown here, so this call
# raises NameError on a fresh kernel — presumably scratch work aimed at `discriminator`;
# TODO confirm and remove or rewrite.
disc(imgs.reshape(200,120,3).to(device).float())

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Take the fifth generated image (index 4), detach it from autograd, move it to the CPU
# and view it as a 200x120x3 array. Requires the last batch to hold at least five images.
x=generated_img[4].detach().to('cpu').reshape(200,120,3)

In [None]:
# Convert the tensor to a NumPy array for plotting.
# NOTE: this rebinds `x`, so re-running the cell fails (an ndarray has no `.numpy()`);
# re-run the previous cell first.
x=x.numpy()

In [None]:
# The training images are read with cv2.imread, which returns channels in BGR order,
# while matplotlib interprets the last axis as RGB. Reverse the channel axis so the
# generated image is displayed with the correct colours.
plt.imshow(x[..., ::-1].astype(int))