# Training a UNet to predict the ab channels from L with an L1 loss

In [None]:
import torch
import torch.utils.data as data
from data.dataset import *
from models.generator import *
from utils.images import *
from models.trainer import *

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Shared settings for both COCO splits.
COCO_ROOT = '/datasets/coco'
COCO_VERSION = "2014"
IMAGE_SIZE = 256
BATCH_SIZE = 64
NUM_WORKERS = 2

# Training split: shuffled every epoch.
dataset_train = CocoLab(COCO_ROOT, version=COCO_VERSION, size=IMAGE_SIZE, train=True)
trainloader = data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Test split: only used for evaluation, so shuffling adds nothing and only makes
# the batch order (including the last batch, reused after training) vary run to run.
dataset_test = CocoLab(COCO_ROOT, version=COCO_VERSION, size=IMAGE_SIZE, train=False)
testloader = data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# UNet(1, 2): presumably 1 input channel (L) -> 2 output channels (ab); confirm in models.generator.
generator = UNet(1, 2).to(device)

In [None]:
# Pull one batch from the training loader and show its first image in colour.
first_batch = next(iter(trainloader))
L_base, ab_base = first_batch

# Join L and ab along dim 1 (presumably the channel axis) to rebuild Lab images.
Lab = torch.concat([L_base, ab_base], dim=1)
tensor_to_pil(Lab)[0]

In [None]:
# Baseline: the UNet's prediction on the sample batch before any training.
generator.eval()
# Kept on `device`: the training loop below reuses L_base for its periodic previews.
L_base = L_base.to(device)
# Inference only: no_grad avoids recording an autograd graph that would just be detached.
with torch.no_grad():
    ab_pred_notrain = generator(L_base)
Lab_pred_notrain = torch.concat((L_base, ab_pred_notrain), 1)
tensor_to_pil(Lab_pred_notrain)[0]

In [None]:
num_epochs = 500
display_every = 10  # preview a prediction on L_base every `display_every` epochs
# Alternative to the hand-written loop below, using this notebook's variable names
# (train_G_L1 presumably comes from models.trainer; check its signature before use):
# train_avg_loss, test_avg_loss = train_G_L1(num_epochs, generator, trainloader, testloader)

In [None]:
# NOTE(review): 0.01 is high for Adam; pix2pix-style setups typically use ~2e-4.
# Consider lowering it if the loss is unstable.
LEARNING_RATE = 0.01

criterion = nn.L1Loss()
optimizer = torch.optim.Adam(generator.parameters(), lr=LEARNING_RATE)

# Per-epoch mean L1 loss (weighted per sample) on each split, plotted in a later cell.
train_avg_loss = []
test_avg_loss = []

for i in range(num_epochs):
    # ---- training pass ----
    generator.train()
    train_loss_sum = 0.0
    train_count = 0
    for L, ab in trainloader:
        L = L.to(device)
        ab = ab.to(device)

        pred = generator(L)
        loss = criterion(pred, ab)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() stores a plain float, so no device tensors are kept alive across
        # the epoch. Weighting by batch size keeps a smaller final batch from
        # skewing the epoch mean.
        train_loss_sum += loss.item() * len(L)
        train_count += len(L)

    # ---- evaluation pass ----
    # NOTE: `L` deliberately stays a plain loop variable: a later cell reuses the
    # last test batch it leaves behind.
    generator.eval()
    test_loss_sum = 0.0
    test_count = 0
    with torch.no_grad():
        for L, ab in testloader:
            L = L.to(device)
            ab = ab.to(device)

            pred = generator(L)
            loss = criterion(pred, ab)

            test_loss_sum += loss.item() * len(L)
            test_count += len(L)

    train_avg_loss.append(train_loss_sum / train_count if train_count else float('nan'))
    test_avg_loss.append(test_loss_sum / test_count if test_count else float('nan'))

    print('[Epoch {}/{}] '.format(i+1, num_epochs) +
            'train_loss: {:.4f} - '.format(train_avg_loss[-1]) +
            'test_loss: {:.4f}'.format(test_avg_loss[-1]))

    # Preview the prediction on L_base every `display_every` epochs and on the final
    # epoch. (The original `if i % display_every:` fired on every *other* epoch.)
    if (i + 1) % display_every == 0 or i + 1 == num_epochs:
        with torch.no_grad():
            ab_pred = generator(L_base)
        Lab_pred = torch.concat((L_base, ab_pred), 1)
        # Jupyter only renders a cell's last top-level expression, so a bare
        # `tensor_to_pil(...)[0]` here would show nothing; draw it explicitly.
        # `plt` is assumed in scope, as in the loss-plot cell.
        plt.figure(figsize=(4, 4))
        plt.imshow(tensor_to_pil(Lab_pred)[0])
        plt.title('Epoch {}'.format(i + 1))
        plt.axis('off')
        plt.show()

In [None]:
# Plot the per-epoch train/test L1 loss curves in the left half of a wide figure.
fig = plt.figure(figsize=(16, 6))
ax = fig.add_subplot(1, 2, 1)
ax.set_title('Losses')

for curve in (train_avg_loss, test_avg_loss):
    ax.plot(curve)

ax.grid()
ax.legend(['Train', 'Test'])
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss (L1)')

plt.show()

In [None]:
# Prediction of the trained UNet on `L`: the last test batch left in the notebook
# namespace by the evaluation loop above (already on `device`).
generator.eval()
# Inference only: no_grad avoids building an autograd graph just to detach it.
with torch.no_grad():
    ab_pred_train = generator(L)
Lab_pred_train = torch.concat((L, ab_pred_train), 1)

In [None]:
# Show the first colourised image of the post-training prediction.
images_after_training = tensor_to_pil(Lab_pred_train)
images_after_training[0]