In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from torchvision.datasets import FashionMNIST,StanfordCars
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

In [2]:
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# Defining the device

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [4]:
# Defining the model

d = 256

class VAE(nn.Module):
    def __init__(self):
        super().__init__()

        self.encoder1 = nn.Sequential(
            nn.Conv2d(3,4,4),
            nn.ReLU(),
            nn.Conv2d(4,1,4),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(84609,d*2)
        )

        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(1,8,4,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(8,4,4,stride=1,padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(4,4,4,stride=1),
            nn.Sigmoid()
        )
        self.transfer=nn.Sequential(
            nn.Linear(d,84609),
            nn.Unflatten(1,(237,357))
        )
        
        
    def reparameterise(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def forward(self, y):
        mu_logvar = self.encoder1(y)
#         print(mu_logvar.shape)
#         z=0
#         mu1=0
#         logvar1=0
        mu_logvar=mu_logvar.view(-1,2,d)
        mu1= mu_logvar[:, 0, :]
        logvar1 = mu_logvar[:, 1, :]
        z = self.reparameterise(mu1, logvar1).view(-1,d)
#         print(z.shape)
        z = self.transfer(z).view(-1,1,237,357)
        z=self.decoder1(z)
#         print("output=",z.shape)
        return z,mu1,logvar1

In [5]:
model = VAE().to(device)

In [6]:
tester=torch.ones((2,3,720,480))
tester=tester.to(device)
print(model(tester)[0].shape)

torch.Size([2, 4, 480, 720])


In [7]:

# Setting the optimiser

learning_rate = 1e-3

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
)

In [8]:
# Reconstruction + KL divergence losses summed over all elements and batch

def loss_function(ỹ, y, mu1, logvar1):
    BCE = nn.functional.binary_cross_entropy(
        ỹ, y, reduction='sum'
    )
    KLD = (-0.5 * torch.sum(-logvar1.exp() + logvar1 + 1.0 - mu1.pow(2)))
    return BCE + 10*KLD,BCE,KLD

In [9]:
class MyDataset(Dataset):
    def __init__(self, train_path):
        self.df = pd.read_csv(train_path, sep=',', usecols=['input', 'output'])
#         print(self.df.shape)
    def __getitem__(self, index):
        x = np.array(Image.open(self.df.iloc[index, 1]))
        y = np.array(Image.open(self.df.iloc[index, 0]))
  
        x, y = torch.from_numpy(x), torch.from_numpy(y)
        return x, y

    def __len__(self):
        return 2000

In [10]:
train_loader=MyDataset("./dataset_train.csv")
train_loader=DataLoader(train_loader, batch_size=16,shuffle=True)


train_loader

In [13]:
# Training and testing the VAE
do=nn.Dropout()
epochs = 100
codes = dict(μ=list(), logσ2=list(), x=list())
for epoch in range(0, epochs + 1):
    # Training
    if epoch > 0:  # test untrained net first
        model.train()
        train_loss = 0
        bcs=0
        kls=0
        for x,y in train_loader:
            x = x.to(device)
            y = y.to(device)
            x=x.view(-1,3,480,720)
            y=y.view(-1,4,480,720)
            x=torch.div(x,255)
            y=torch.div(y,255)
            # ===================forward=====================
            y_bar, mu1, logvar1 = model(x)
            loss,bc,kl = loss_function(y_bar, y, mu1, logvar1)
            train_loss += loss
            bcs+=bc.item()
            kls+=kl
            # ===================backward====================
            optimizer.zero_grad()
            loss.backward()
#             print(loss.item())
            optimizer.step()
        # ===================log========================
        print(f'====> Epoch: {epoch} Average loss: {train_loss /20000} BCE Loss: {bcs / 2000} KLD Loss: {kls / 2000}')
        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    }, "./Weights/basic_net.pt")
        
    # Testing
    
#     means, logvars, labels = list(), list(), list()
#     with torch.no_grad():
#         model.eval()
#         test_loss = 0
#         bcs=0
#         kls=0
#         for y, x in test_loader:
#             y = y.to(device)
#             y=y.view(-1,1,28,28)
#             # ===================forward=====================
#             ỹ, mu1, logvar1,mu2,logvar2,mu3,logvar3 = model(y)
#             # print(ỹ.shape)
#             # print(y.shape)
#             loss,bc,kl = loss_function(ỹ, y, mu1, logvar1,mu2,logvar2,mu3,logvar3)
#             test_loss+=loss.item()
#             bcs+=bc.item()
#             kls+=kl.item()
#             # =====================log=======================
#             means.append(mu3.detach())
#             logvars.append(logvar3.detach())
#             labels.append(x.detach())
#     # ===================log========================
#     codes['μ'].append(torch.cat(means))
#     codes['logσ2'].append(torch.cat(logvars))
#     codes['x'].append(torch.cat(labels))
#     test_loss /= len(test_loader.dataset)
#     print(f'Average loss: {test_loss} BCE Loss: {bcs / len(test_loader.dataset):.4f} KLD Loss: {kls / len(test_loader.dataset):.4f}')

====> Epoch: 1 Average loss: 75157.4765625 BCE Loss: 742630.112 KLD Loss: 894.4623413085938
====> Epoch: 2 Average loss: 69840.9296875 BCE Loss: 695016.7315 KLD Loss: 339.2535400390625
====> Epoch: 3 Average loss: 69560.8984375 BCE Loss: 693388.832 KLD Loss: 222.0042266845703
====> Epoch: 4 Average loss: 69548.15625 BCE Loss: 693457.415 KLD Loss: 202.4368896484375
====> Epoch: 5 Average loss: 69534.9453125 BCE Loss: 692944.747 KLD Loss: 240.4820556640625
====> Epoch: 6 Average loss: 69792.984375 BCE Loss: 694602.0185 KLD Loss: 332.7666931152344
====> Epoch: 7 Average loss: 69687.3984375 BCE Loss: 692864.651 KLD Loss: 400.9363098144531
====> Epoch: 8 Average loss: 69691.3515625 BCE Loss: 693236.9285 KLD Loss: 367.6723937988281
====> Epoch: 9 Average loss: 69614.15625 BCE Loss: 692560.52 KLD Loss: 358.0784606933594
====> Epoch: 10 Average loss: 69525.0859375 BCE Loss: 692679.9815 KLD Loss: 257.0904235839844
====> Epoch: 11 Average loss: 69575.6640625 BCE Loss: 693318.715 KLD Loss: 243.79

KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()

In [None]:
test_loader=MyDataset("./dataset_test.csv")

In [31]:
temp =np.array(Image.open("./Datasets/Input/Echendens-LHS_09620.png_6.png"), dtype = float)/255.0

In [43]:
tem = torch.from_numpy(temp).view(-1,3,480,720)

In [46]:
tem=tem.to(device,dtype=torch.float32)

In [61]:
ans=(model(tem))

In [62]:
ans=(ans[0]*255).detach().cpu().numpy()

In [67]:
ans.shape

(1, 4, 480, 720)

In [66]:
img=Image.fromarray(ans)

TypeError: Cannot handle this data type: (1, 1, 480, 720), <f4