In [1]:
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torch.nn.functional import pad
from Dataset import Dataset, Dataset_warped
from Example import FeatureExtraction, ProcessImage, Model
import cv2
import matplotlib.pyplot as plt

In [2]:
# Hyperparameters, model construction, dataset and optimizers.
BATCH = 1
EPOCHS = 1
# NOTE(review): lr=1 is three orders of magnitude above Adam's default (1e-3)
# and is likely to make training diverge — confirm this is intentional.
LR = 1
SEED = 0
# Seed torch's global RNG (CPU and CUDA) so weight initialisation, random_split
# and DataLoader shuffling are reproducible when the cells are run in order.
torch.manual_seed(SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"

featExtrac = FeatureExtraction(3).to(device)
model = Model().to(device)
# NOTE(review): ProcessImage args look like (in_channels, out_channels), matching
# the concatenated inputs and the targets used in the training loop — TODO confirm.
level1 = ProcessImage(96, 3).to(device)
level2 = ProcessImage(192, 64).to(device)
level3 = ProcessImage(384, 128).to(device)
level4 = ProcessImage(512, 256).to(device)

dir_dataset = '../dataset2'
dataset = Dataset(dir=dir_dataset, transform=transforms.ToTensor())

# One Adam optimizer per level; featExtrac and model are not optimised here.
optimizer1 = optim.Adam(level1.parameters(), lr=LR)
optimizer2 = optim.Adam(level2.parameters(), lr=LR)
optimizer3 = optim.Adam(level3.parameters(), lr=LR)
optimizer4 = optim.Adam(level4.parameters(), lr=LR)

In [3]:
# Hold out ~10% of the samples for testing (at least one, so the preview cell
# always has a sample to draw), then wrap both splits in DataLoaders.
# NOTE: `trainset` / `testset` are rebound from Subsets to DataLoaders; later
# cells iterate over the loaders under these names.
length = len(dataset)
test = max(1, length // 10)
train = length - test
trainset, testset = random_split(dataset, [train, test])
trainset = DataLoader(
    trainset,
    batch_size=BATCH,
    shuffle=True,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
    pin_memory=torch.cuda.is_available(),
    num_workers=2,
)
testset = DataLoader(testset, batch_size=1, shuffle=True)

In [4]:
# Train the four ProcessImage levels to predict the middle frame F2 from the
# features of frames F1 and F3. featExtrac is frozen (run under no_grad);
# level4 runs first and each level's `up` output is concatenated into the
# input of the next level down.
lossFunction1 = torch.nn.SmoothL1Loss()
lossFunction2 = torch.nn.SmoothL1Loss()
lossFunction3 = torch.nn.SmoothL1Loss()
lossFunction4 = torch.nn.SmoothL1Loss()

levels = [level1, level2, level3, level4]
optimizers = [optimizer1, optimizer2, optimizer3, optimizer4]

# Batches per epoch. The original debug run stopped after the first batch;
# set to None to iterate over the whole training set.
MAX_BATCHES = 1
# H and W are padded up to a multiple of this. NOTE(review): presumably the
# total downsampling factor of FeatureExtraction — TODO confirm.
PAD_MULTIPLE = 8


def _pad_to_multiple(x, multiple=PAD_MULTIPLE):
    """Zero-pad the bottom and right of an (N, C, H, W) tensor so H and W are multiples of `multiple`."""
    h, w = x.shape[-2:]
    pad_h = (-h) % multiple  # rows needed to reach the next multiple (0 if already aligned)
    pad_w = (-w) % multiple  # columns needed to reach the next multiple
    if pad_h == 0 and pad_w == 0:
        return x
    # F.pad lists padding from the last dim backwards: (W_left, W_right, H_top, H_bottom).
    return pad(x, (0, pad_w, 0, pad_h), "constant", 0)


for level in levels:
    level.train()

for epoch in range(EPOCHS):
    for step, (F1, F2, F3) in enumerate(trainset):
        F1, F2, F3 = (_pad_to_multiple(frame).to(device) for frame in (F1, F2, F3))

        # Clear gradients up front so an interrupted earlier run cannot leak
        # stale gradients into this step.
        for optimizer in optimizers:
            optimizer.zero_grad()

        with torch.no_grad():
            features1 = featExtrac(F1)
            features2 = featExtrac(F2)
            features3 = featExtrac(F3)
        # For each feature index i, concatenate F1's and F3's features along channels.
        input1, input2, input3, input4 = (
            torch.cat([features1[i], features3[i]], dim=1) for i in range(4)
        )

        out4, up4 = level4(input4)
        out3, up3 = level3(torch.cat([input3, up4], dim=1))
        out2, up2 = level2(torch.cat([input2, up3], dim=1))
        out1, up1 = level1(torch.cat([input1, up2], dim=1))

        loss4 = lossFunction4(out4, features2[3])
        loss3 = lossFunction3(out3, features2[2])
        loss2 = lossFunction2(out2, features2[1])
        loss1 = lossFunction1(out1, F2)

        # One backward pass over the summed loss. The levels form a single graph
        # (up4 feeds level3, up3 feeds level2, ...), so separate backward() calls
        # without retain_graph=True can fail when a later loss reaches nodes whose
        # saved tensors an earlier call already freed. The gradient of the sum
        # equals the sum of the per-loss gradients.
        loss = loss1 + loss2 + loss3 + loss4
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()

        if MAX_BATCHES is not None and step + 1 >= MAX_BATCHES:
            break

KeyboardInterrupt: 

In [None]:
# Run the model on one (shuffled) test sample and convert the frames to
# H x W x C numpy arrays for plotting.
# NOTE(review): `model` is a freshly constructed Model(); the training cell only
# optimises level1..level4. Unless Model loads trained weights itself, this
# output comes from untrained parameters — TODO confirm.
model.eval()  # switch any dropout / batch-norm layers to inference behaviour
data = next(iter(testset))
(F1, F2, F3) = data
with torch.no_grad():
    output = model(F1.to(device), F3.to(device))
# The loader tensors are already on the CPU; [0] takes the only item of the batch.
img1 = F1.numpy()[0].transpose(1, 2, 0)
img2 = F2.numpy()[0].transpose(1, 2, 0)
img3 = F3.numpy()[0].transpose(1, 2, 0)
# Computed under no_grad, so no detach() is needed before numpy().
generated = output.cpu().numpy()[0].transpose(1, 2, 0)
del output, data

In [None]:
# Show the three input frames next to the generated frame in a 2x2 grid.
# NOTE(review): each image is converted BGR -> RGB. ToTensor on PIL images
# already yields RGB, so this swap is only right if Dataset loads frames with
# OpenCV (BGR) — TODO confirm.
panels = [
    ("First", img1),
    ("Second", img2),
    ("Third", img3),
    ("generated", generated),
]
fig, axes = plt.subplots(2, 2, figsize=(10, 8))
for ax, (title, image) in zip(axes.flat, panels):
    ax.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    ax.set_title(title)
    ax.axis("off")
plt.show()