In [1]:
import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from torchvision.datasets import FashionMNIST,StanfordCars
from matplotlib import pyplot as plt
import numpy as np
import torch.nn.functional as F
import math
import pandas as pd
from PIL import Image

In [2]:
import wandb

In [3]:
# IPython help lookup for wandb.init (interactive development aid only; safe to delete).
?wandb.init

In [4]:
# Let PIL load image files whose data is truncated instead of raising an error on them.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [5]:
# Run on the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [1]:
class Block(nn.Module):
    """Two 3x3 convolutions (no padding) with a ReLU between them.

    There is no activation after the second convolution. Because neither conv is
    padded, each one trims a pixel from every border, so H and W shrink by 4 overall.

    Args:
        in_ch: channels of the incoming feature map.
        out_ch: channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)


class Encoder(nn.Module):
    """U-Net contracting path.

    Runs the Blocks in order with a 2x2 max-pool between consecutive Blocks, and
    returns every Block's output (the skip features), shallowest first.

    Args:
        chs: channel counts; Block i maps chs[i] -> chs[i+1] channels.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
        self.pool       = nn.MaxPool2d(2)

    def forward(self, x):
        ftrs = []
        last = len(self.enc_blocks) - 1
        for i, block in enumerate(self.enc_blocks):
            x = block(x)
            ftrs.append(x)
            # Pool only between Blocks: the pooled output of the last Block is never
            # used, and pooling it raises in MaxPool2d when that map is below 2x2.
            if i < last:
                x = self.pool(x)
        return ftrs


class Decoder(nn.Module):
    """U-Net expanding path: upsample, concatenate the matching skip, then convolve.

    Args:
        chs: channel counts starting at the bottleneck; stage i maps chs[i] -> chs[i+1].
    """

    def __init__(self, chs=(1024, 512, 256, 128, 64)):
        super().__init__()
        self.chs = chs
        stage_pairs = list(zip(chs[:-1], chs[1:]))
        # kernel 2 / stride 2 transposed convs double H and W.
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_pairs])
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in stage_pairs])

    def forward(self, x, encoder_features):
        """Decode ``x`` using ``encoder_features[i]`` as the skip input of stage i."""
        for stage, (upconv, block) in enumerate(zip(self.upconvs, self.dec_blocks)):
            x = upconv(x)
            skip = self.crop(encoder_features[stage], x)
            x = block(torch.cat([x, skip], dim=1))
        return x

    def crop(self, enc_ftrs, x):
        """Center-crop ``enc_ftrs`` to the spatial size (H, W) of ``x``."""
        _, _, target_h, target_w = x.shape
        cropper = torchvision.transforms.CenterCrop([target_h, target_w])
        return cropper(enc_ftrs)


class UNet(nn.Module):
    """Encoder -> Decoder -> 1x1 convolution head.

    The deepest encoder output is the decoder's starting map. The remaining encoder
    outputs, deepest first, are its skip connections. Because the Blocks use unpadded
    convolutions, the output is spatially smaller than the input unless ``retain_dim``
    is set.

    Args:
        enc_chs: channel counts for the Encoder.
        dec_chs: channel counts for the Decoder; dec_chs[-1] feeds the head.
        num_class: output channels of the 1x1 head.
        retain_dim: if True, resize the head output to ``out_sz`` with
            ``F.interpolate`` (default mode, i.e. nearest).
        out_sz: (H, W) target size used when ``retain_dim`` is True.
    """

    def __init__(self, enc_chs=(3,64,128,256,512,1024), dec_chs=(1024, 512, 256, 128, 64), num_class=1, retain_dim=False, out_sz=(572,572)):
        super().__init__()
        self.encoder     = Encoder(enc_chs)
        self.decoder     = Decoder(dec_chs)
        self.head        = nn.Conv2d(dec_chs[-1], num_class, 1)
        self.retain_dim  = retain_dim
        # Keep out_sz on the instance so forward() can use it; as a plain argument it
        # was out of scope there and retain_dim=True raised NameError.
        self.out_sz      = out_sz

    def forward(self, x):
        enc_ftrs = self.encoder(x)[::-1]  # deepest feature map first
        out      = self.decoder(enc_ftrs[0], enc_ftrs[1:])
        out      = self.head(out)
        if self.retain_dim:
            out = F.interpolate(out, self.out_sz)
        return out

In [11]:
# NOTE(review): `model` and `dims` are used by later cells (optimiser, wandb.config,
# wandb.watch, training loop), but their only definitions are the commented-out lines
# below. Restart & Run All therefore fails with NameError. Define the model that is
# actually trained here.
# dims=1024
# model = VAE(dims).to(device)

In [12]:

# Optimiser: Adam over every parameter of `model`, with a fixed learning rate.
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [13]:
# Reconstruction loss: binary cross-entropy summed over all elements and the batch.
# No KL-divergence term is included.

def loss_function(ỹ, y, reduction='sum'):
    """Binary cross-entropy between a prediction and its target.

    Args:
        ỹ: predicted values; ``binary_cross_entropy`` requires them to lie in [0, 1].
        y: target tensor with the same shape as ``ỹ``, values in [0, 1].
        reduction: forwarded to ``binary_cross_entropy``; defaults to ``'sum'``.

    Returns:
        The reduced loss tensor.
    """
    return nn.functional.binary_cross_entropy(ỹ, y, reduction=reduction)

In [14]:
class MyDataset(Dataset):
    """Paired-image dataset read from a CSV with ``input`` and ``output`` path columns.

    Item ``i`` is ``(x, y)``. ``x`` is the image at row ``i``'s ``input`` path and ``y``
    is the image at its ``output`` path, each opened with PIL and converted to a numpy
    array.

    Args:
        train_path: path to the CSV file.
        transform_x: optional callable applied to the ``x`` array; without it ``x`` is
            returned as ``torch.from_numpy(x)``.
        transform_y: the same, for ``y``.
        max_items: upper bound on ``len(self)`` (default 3000); ``None`` uses every row.
    """

    def __init__(self, train_path, transform_x=None, transform_y=None, max_items=3000):
        self.df = pd.read_csv(train_path, sep=',', usecols=['input', 'output'])
        self.transform_x = transform_x
        self.transform_y = transform_y
        self.max_items = max_items

    def __getitem__(self, index):
        # Look paths up by column name. pandas ignores the order given in `usecols`
        # and keeps the CSV's own column order, so positional access depended on how
        # the file was written.
        # NOTE(review): x = `input`, y = `output` matches the rest of the notebook
        # (3-channel model input, 4-channel target); confirm against the CSV.
        row = self.df.iloc[index]
        x = self._load_image(row['input'])
        y = self._load_image(row['output'])
        # Each transform is applied on its own, so supplying just one of them works.
        x = self.transform_x(x) if self.transform_x is not None else torch.from_numpy(x)
        y = self.transform_y(y) if self.transform_y is not None else torch.from_numpy(y)
        return x, y

    def __len__(self):
        # Never report more items than the CSV has rows; otherwise the DataLoader
        # requests indices past the end of the frame.
        if self.max_items is None:
            return len(self.df)
        return min(len(self.df), self.max_items)

    @staticmethod
    def _load_image(path):
        """Read an image file into a numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img)

In [15]:
# Training schedule: number of epochs and images per DataLoader batch.
epochs, batch_size = 1000, 16

In [16]:
# Hyper-parameters to record for this W&B run.
# NOTE(review): `dims` has no live definition in this notebook (only `# dims=1024`),
# so this cell raises NameError on a fresh kernel. Define it or drop the key.
wandb.config = dict(
    learning_rate=learning_rate,
    epochs=epochs,
    batch_size=batch_size,
    dims=dims,
)

In [17]:
# Start the W&B run and record the hyper-parameters with it.
# wandb.init() installs the run's own config object as `wandb.config` (per the W&B
# docs; confirm for the installed version). A dict assigned to `wandb.config`
# beforehand is therefore lost unless it is passed in explicitly here.
run_config = wandb.config if isinstance(wandb.config, dict) else None
wandb.init(project="AerialPoseEstimator", config=run_config)

[34m[1mwandb[0m: Currently logged in as: [33mpthpth[0m. Use [1m`wandb login --relogin`[0m to force relogin
wandb: ERROR Failed to sample metric: Not Supported


In [18]:
# Untransformed loaders: the training one is used below to measure channel statistics.
train_loader = DataLoader(MyDataset("./dataset_train.csv"), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(MyDataset("./dataset_test.csv"), batch_size=batch_size, shuffle=True)

In [19]:
# Ask W&B to track `model` during training (what is logged follows wandb.watch's
# `log` argument, left at its default here). The `[]` shown below is its return value.
wandb.watch(model)

[]

In [20]:
def batch_mean_x(loader):
    """Per-channel mean and standard deviation of the 3-channel inputs in ``loader``.

    ``loader`` must yield ``(images, targets)`` pairs, where ``images`` is a
    channels-last batch ``(B, H, W, C)`` of raw 0-255 pixel values. This is what
    ``MyDataset`` produces without transforms. Pixels are scaled to [0, 1] before the
    statistics are taken.

    Returns:
        ``(mean, std)``: two float32 tensors of length 3. Both are zeros if ``loader``
        yields no batches.
    """
    # zeros, not torch.empty: uninitialised memory can hold NaN/Inf, and 0 * NaN is NaN.
    total = torch.zeros(3, dtype=torch.float64)
    total_sq = torch.zeros(3, dtype=torch.float64)
    n_pixels = 0
    for images, _ in loader:
        batch, height, width, _channels = images.shape
        images = images.to(torch.float64) / 255
        total += images.sum(dim=(0, 1, 2))
        total_sq += images.square().sum(dim=(0, 1, 2))
        n_pixels += batch * height * width
    if n_pixels == 0:
        return total.float(), total_sq.float()
    mean = total / n_pixels
    # E[x^2] - E[x]^2 can dip just below zero through rounding; clamp before sqrt.
    var = (total_sq / n_pixels - mean.square()).clamp_min(0)
    return mean.float(), var.sqrt().float()

In [21]:
def batch_mean_y(loader):
    """Per-channel mean and standard deviation of the 4-channel targets in ``loader``.

    ``loader`` must yield ``(inputs, images)`` pairs, where ``images`` is a
    channels-last batch ``(B, H, W, C)`` of raw 0-255 pixel values. This is what
    ``MyDataset`` produces without transforms. Pixels are scaled to [0, 1] before the
    statistics are taken.

    Returns:
        ``(mean, std)``: two float32 tensors of length 4. Both are zeros if ``loader``
        yields no batches.
    """
    # zeros, not torch.empty: uninitialised memory can hold NaN/Inf, and 0 * NaN is NaN.
    total = torch.zeros(4, dtype=torch.float64)
    total_sq = torch.zeros(4, dtype=torch.float64)
    n_pixels = 0
    for _, images in loader:
        batch, height, width, _channels = images.shape
        images = images.to(torch.float64) / 255
        total += images.sum(dim=(0, 1, 2))
        total_sq += images.square().sum(dim=(0, 1, 2))
        n_pixels += batch * height * width
    if n_pixels == 0:
        return total.float(), total_sq.float()
    mean = total / n_pixels
    # E[x^2] - E[x]^2 can dip just below zero through rounding; clamp before sqrt.
    var = (total_sq / n_pixels - mean.square()).clamp_min(0)
    return mean.float(), var.sqrt().float()

In [22]:
# Channel statistics of the raw training images, on a [0, 1] pixel scale.
mean_x, std_x = batch_mean_x(train_loader)
mean_y, std_y = batch_mean_y(train_loader)  # NOTE(review): not used elsewhere in this notebook

# Inputs: ToTensor (for uint8 HWC arrays: CHW float in [0, 1]), then per-channel standardisation.
transform_img_normal_x = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean_x, std=std_x),
])
# Targets: ToTensor; Normalize(mean=0, std=1) leaves the values unchanged.
transform_img_normal_y = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=0, std=1),
])

train_loader = DataLoader(
    MyDataset("./dataset_train.csv",
              transform_x=transform_img_normal_x,
              transform_y=transform_img_normal_y),
    batch_size=batch_size,
    shuffle=True,
)
test_loader = DataLoader(
    MyDataset("./dataset_test.csv",
              transform_x=transform_img_normal_x,
              transform_y=transform_img_normal_y),
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Training and evaluation loop.
# Each 480x720 frame is cut into a 2x2 grid of 240x360 tiles. The tiles are stacked
# along the batch dimension before being fed to the model.
import os

CHECKPOINT_PATH = "./Weights/resnet.pt"
FRAME_H, FRAME_W = 480, 720
TILE_H, TILE_W = 240, 360


def _to_float_chw(t, channels):
    """Return batch ``t`` as a float ``(B, channels, FRAME_H, FRAME_W)`` tensor.

    Batches from the ToTensor-based transforms are already float and channels-first,
    with ToTensor having scaled them to [0, 1], so they are only reshaped. They are
    NOT divided by 255 again. A raw uint8 batch (a MyDataset without transforms yields
    channels-last numpy images) is permuted to channels-first and scaled to [0, 1].
    """
    if t.dtype == torch.uint8:
        t = t.permute(0, 3, 1, 2).float() / 255
    return t.reshape(-1, channels, FRAME_H, FRAME_W)


def _tile(t):
    """Cut a ``(B, C, H, W)`` batch into TILE_H x TILE_W tiles concatenated on dim 0.

    Tiles are ordered strip by strip (width outer, height inner). Inputs and targets
    use the same order, so each input tile stays paired with its target tile.
    """
    return torch.cat([tile for strip in t.split(TILE_W, -1) for tile in strip.split(TILE_H, -2)])


def _prepare_batch(x, y):
    """Move an (input, target) batch to `device`, fix its layout and tile it."""
    x = _tile(_to_float_chw(x.to(device), 3))
    y = _tile(_to_float_chw(y.to(device), 4))
    return x, y


for epoch in range(0, epochs + 1):
    # Epoch 0 only evaluates, giving a baseline for the untrained network.
    if epoch > 0:
        model.train()
        train_loss = 0.0
        for x, y in train_loader:
            x, y = _prepare_batch(x, y)
            loss = loss_function(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate a Python float: summing the loss tensor itself keeps every
            # batch's autograd graph alive until the end of the epoch.
            train_loss += loss.item()

        wandb.log({"train_loss": train_loss / len(train_loader.dataset)})
        if epoch % 10 == 0:
            os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': train_loss},
                       CHECKPOINT_PATH)
    torch.cuda.empty_cache()

    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = _prepare_batch(x, y)
            test_loss += loss_function(model(x), y).item()
    # Normalise by the dataset size exactly once.
    test_loss /= len(test_loader.dataset)
    wandb.log({"test_loss": test_loss})
    print(epoch)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21


In [None]:
# test_loader=MyDataset("./dataset_test.csv")

In [None]:
# temp =np.array(Image.open("./Datasets/Input/Echendens-LHS_09620.png_6.png"), dtype = float)/255.0

In [None]:
# tem = torch.from_numpy(temp).view(-1,3,480,720)

In [None]:
# tem=tem.to(device,dtype=torch.float32)

In [None]:
# ans=(model(tem))

In [None]:
# ans=(ans[0]*255).detach().cpu().numpy()

In [None]:
# ans.shape

In [None]:
# img=Image.fromarray(ans)