# Imports

In [None]:
import os,cv2,torch
import numpy as np
from torch.utils.data import DataLoader
import torchvision
import pickle
import torch.nn as nn
import torchvision
import torch
import numpy as np
import math
from torch import Tensor
from functools import partial

from utils.args import *
from utils.arch import *
from utils.helper import *

# Dataset paths and the pickle files created by the Dataset_setup notebook

In [None]:
train_path = 'data/train/'
val_path = 'data/val/'


def _load_pickle(pkl_path):
    """Return the object stored in the pickle file at `pkl_path`.

    NOTE: pickle can execute arbitrary code on load; these files are expected
    to be produced locally by the Dataset_setup notebook, never downloaded.
    """
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)


# Per-split dicts keyed by resolution bucket (indexed by 'tall'/'square'/'wide' below).
train_dict = _load_pickle('train_dict.pkl')
val_dict = _load_pickle('val_dict.pkl')

# Train and Val datasets

In [None]:
def _pair_paths(root, names):
    """Pair every annotation filename in `names` with its image path under `root`.

    Returns a list of (image_path, annotation_path) tuples. The image name is
    derived by dropping the last three characters of the annotation name and
    appending 'jpg' — assumes annotation files carry a 3-letter extension
    (e.g. '.txt'); TODO confirm against Dataset_setup.
    """
    return [
        (os.path.join(root, name[:-3] + 'jpg'), os.path.join(root, name))
        for name in names
    ]


train_dataset = {bucket: _pair_paths(train_path, train_dict[bucket]) for bucket in ('tall', 'square', 'wide')}
val_dataset = {bucket: _pair_paths(val_path, val_dict[bucket]) for bucket in ('tall', 'square', 'wide')}

# Train and Val dataloaders

In [None]:
def _make_loaders(datasets, shuffle):
    """Build one DataLoader per resolution bucket ('tall', 'square', 'wide').

    Each bucket is collated by `generate_batch` at that bucket's target size;
    the *_res tuples are indexed as (h, w) = (res[0], res[1]).

    Args:
        datasets: dict mapping bucket name -> list of (image_path, label_path).
        shuffle: whether each DataLoader reshuffles every epoch.

    Returns:
        dict mapping bucket name -> DataLoader, in 'tall', 'square', 'wide' order.
    """
    bucket_sizes = {'tall': tall_res, 'square': square_res, 'wide': wide_res}
    return {
        bucket: DataLoader(
            datasets[bucket],
            batch_size=BATCH_SIZE,
            shuffle=shuffle,
            collate_fn=partial(generate_batch, w=size[1], h=size[0]),
        )
        for bucket, size in bucket_sizes.items()
    }


train_loaders = _make_loaders(train_dataset, shuffle=True)
# Validation is not shuffled: with shuffling, batch composition (and thus the
# summed per-batch val loss used to pick checkpoints) changed from run to run.
val_loaders = _make_loaders(val_dataset, shuffle=False)


# Number of batches in each train and val dataloader

In [None]:
# Number of batches each per-resolution DataLoader yields per epoch.
train_lens = {bucket: len(loader) for bucket, loader in train_loaders.items()}
val_lens = {bucket: len(loader) for bucket, loader in val_loaders.items()}

# Initialize the network, loss function and optimizer

In [None]:
net = Network(DEVICE).to(DEVICE)

# To resume training from an intermediate checkpoint, uncomment:
# model_path = '<model_checkpoint>.pth'
# dic=torch.load(model_path,map_location=torch.device(DEVICE))
# net.load_state_dict(dic)

# Target token id 1 contributes no loss — presumably the padding token;
# TODO confirm against the tokenizer/vocabulary in utils.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=1)

optimizer = torch.optim.Adam(
    net.parameters(),
    lr=1e-4,
    betas=(0.9, 0.98),
    eps=1e-9,
)


# Image augmentations

In [None]:
# Candidate augmentations; RandomChoice applies exactly one of them per call.
_augmentations = [
    torchvision.transforms.ColorJitter(brightness=0.85, contrast=0.85, saturation=0.85, hue=0.5),
    torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5)),
]
T = torchvision.transforms.RandomChoice(_augmentations)

# Single probe iteration to check the largest batch size that fits in device memory

In [None]:
# Run one full forward/backward/optimizer step on a 'wide' batch so that an
# out-of-memory error shows up now rather than mid-epoch. This includes the
# Adam state that step() allocates.
# NOTE: optimizer.step() really updates the weights and Adam state, so this
# probe counts as one (extra) training step.
probe_samples = train_dataset['wide'][11*BATCH_SIZE:12*BATCH_SIZE]
if not probe_samples:
    # Fewer than 12 batches of 'wide' data: use the first batch instead of
    # passing an empty list to generate_batch.
    probe_samples = train_dataset['wide'][:BATCH_SIZE]

# This cell is re-run while tuning BATCH_SIZE; clear gradients so earlier
# runs do not accumulate into this step (mirrors the main training loop).
optimizer.zero_grad()

imgs, tgt_pad, targets = generate_batch(probe_samples, w=wide_res[1], h=wide_res[0])
imgs = imgs.to(DEVICE)
targets = targets.to(DEVICE)
tgt_pad = tgt_pad.to(DEVICE)

imgs = torch.stack([T(x) for x in imgs])

# Teacher forcing: feed targets[:-1] and score against targets[1:].
logits = net(imgs/255, targets[:-1, :], tgt_pad)

targets = targets[1:].reshape(-1)
logits = logits.reshape(-1, logits.shape[-1])

loss = loss_fn(logits, targets)
loss.backward()
optimizer.step()

# Main training loop

In [None]:
_RESOLUTIONS = ['tall', 'square', 'wide']


def _sample_resolution(batches_left):
    """Return a resolution key that still has batches, weighted by resolution_dist.

    resolution_dist is read as probabilities aligned with
    ['tall', 'square', 'wide'], the same order the original
    np.random.choice call used. Probabilities are renormalised over the keys
    that still have batches. That is the same distribution as re-drawing
    until an unfinished key comes up, but it cannot loop forever when every
    unfinished key has probability 0; in that case it picks uniformly.

    Args:
        batches_left: dict mapping resolution key -> batches still to draw.
            At least one value must be > 0.
    """
    remaining = [res for res in _RESOLUTIONS if batches_left[res] > 0]
    weights = np.array([resolution_dist[_RESOLUTIONS.index(res)] for res in remaining], dtype=float)
    total = weights.sum()
    if total > 0:
        weights = weights / total
    else:
        weights = np.full(len(remaining), 1.0 / len(remaining))
    return remaining[np.random.choice(len(remaining), p=weights)]


def _batch_loss(batch, augment):
    """Move one collated batch to DEVICE, run `net`, and return the scalar loss.

    Args:
        batch: the (imgs, tgt_pad, targets) triple produced by generate_batch.
        augment: if True, apply the random augmentation T to every image.

    Targets are teacher-forced: the model sees targets[:-1] and is scored
    against targets[1:]. The slicing implies the sequence axis is dim 0.
    """
    imgs, tgt_pad, targets = (t.to(DEVICE) for t in batch)
    if augment:
        imgs = torch.stack([T(x) for x in imgs])
    logits = net(imgs / 255, targets[:-1, :], tgt_pad)
    logits = logits.reshape(-1, logits.shape[-1])
    return loss_fn(logits, targets[1:].reshape(-1))


min_val_loss = float('inf')

for epoch in range(EPOCH):

    losses = 0.0
    val_losses = 0.0

    print("Starting train for ", epoch)

    net = net.train()
    # Make every parameter trainable at the start of each epoch, including any
    # that were frozen.
    for p in net.parameters():
        p.requires_grad = True

    # Interleave batches from the three resolution loaders at random until all
    # are exhausted. Counting down batches (instead of waiting to see index
    # len-1) means an empty loader is simply never picked, rather than causing
    # a StopIteration.
    train_iters = {res: iter(loader) for res, loader in train_loaders.items()}
    batches_left = dict(train_lens)
    while any(n > 0 for n in batches_left.values()):
        sel = _sample_resolution(batches_left)
        batch = next(train_iters[sel])
        batches_left[sel] -= 1

        optimizer.zero_grad()
        loss = _batch_loss(batch, augment=True)
        loss.backward()
        optimizer.step()

        losses += loss.item()

    print("Starting val for ", epoch)

    net = net.eval()
    # The summed loss does not depend on batch order, so walk each loader in
    # turn. no_grad() skips graph construction without mutating requires_grad.
    with torch.no_grad():
        for loader in val_loaders.values():
            for batch in loader:
                val_losses += _batch_loss(batch, augment=False).item()

    # Mean per-batch loss; max(..., 1) guards against empty splits.
    train_loss = losses / max(sum(train_lens.values()), 1)
    val_loss = val_losses / max(sum(val_lens.values()), 1)
    print("Loss for epoch ", epoch, " = ", train_loss, ' and ', val_loss)

    # Checkpoint whenever validation loss matches or beats the best so far.
    if val_loss <= min_val_loss:
        min_val_loss = val_loss
        torch.save(net.state_dict(), file_dest + 'model_' + str(epoch) + '.pth')

