# Libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
import torch
from torch import nn
from torchvision import datasets, transforms, models
from torch.utils.tensorboard import SummaryWriter
from adan_pytorch import Adan

In [None]:
import os
from datetime import datetime

In [None]:
from dataset import listDataset
from model import CSRNet
from model_nostride import CSRNet_ns
from utils import save_net, load_net, save_checkpoint

# Check for GPU availability

In [None]:
# Report whether a CUDA device is visible before any GPU work is attempted.
print(f'Is cuda available: {torch.cuda.is_available()}')
if torch.cuda.is_available():
    print(f'Cuda device name: {torch.cuda.get_device_name(0)}')
else:
    # get_device_name(0) raises when no CUDA device exists, so skip the query.
    print('Cuda device name: none (CUDA not available)')

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

# Data Loader

In [None]:
# Dataset roots. Directory names (including the "ShangaiTech" spelling) are kept
# verbatim because they are listed and opened directly on disk.
# Trailing slashes matter: file names are appended by plain string concatenation.
path2train, path2test = (
    'ShangaiTech/ShanghaiTech/part_A/train_data/images/',
    'ShangaiTech/ShanghaiTech/part_A/test_data/images/',
)
path2den_train, path2den_test = (
    'ShangaiTech_newdensity/ShangaiTech_newdensity/A/train_data/',
    'ShangaiTech_newdensity/ShangaiTech_newdensity/A/test_data/',
)

In [None]:
# Both splits use the same ImageNet-statistics normalisation.
_image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _make_loader(image_dir, density_dir, train):
    """Build a DataLoader pairing every ``*.jpg`` in *image_dir* with its density file.

    The density file name is derived from the image name by replacing ``jpg``
    with ``npy`` and ``IMG`` with ``DEN`` (e.g. ``IMG_1.jpg`` -> ``DEN_1.npy``),
    and is looked up in *density_dir*. The directory is listed once and sorted,
    so the image and density lists come from the same sequence and stay aligned.
    Non-``.jpg`` entries (stray files, checkpoint folders) are skipped.
    """
    names = sorted(n for n in os.listdir(image_dir) if n.endswith('.jpg'))
    images = [image_dir + n for n in names]
    densities = [density_dir + n.replace('jpg', 'npy').replace('IMG', 'DEN') for n in names]
    # NOTE(review): batch_size/num_workers are passed to listDataset, not to the
    # DataLoader, so the DataLoader uses its default batch_size=1. The loss code
    # squeezes tensors as if each batch holds one image, so confirm before moving
    # these arguments to the DataLoader.
    dataset = listDataset(images,
                          densities,
                          shuffle=True,
                          transform=_image_transform,
                          train=train,
                          seen=None,
                          batch_size=32,
                          num_workers=4)
    return torch.utils.data.DataLoader(dataset)


train_loader = _make_loader(path2train, path2den_train, train=True)
test_loader = _make_loader(path2test, path2den_test, train=False)

# Model Initialization

In [None]:
# Instantiate CSRNet, then place its parameters on the selected device.
model = CSRNet()
model = model.to(device)

# Criterion and Optimizer 

In [None]:
# Mean-squared error between predicted and target density maps, averaged over
# all elements (PyTorch's default reduction, spelled out explicitly).
criterion = nn.MSELoss(reduction='mean')

In [None]:
# Optimizer parameters
lr = 1e-7  # NOTE(review): unusually small for Adam; confirm this is intended
momentum = 0.95  # used only by the 'sgd' option
decay = 5 * 1e-4  # weight decay, used only by the 'sgd' option

# Choose one of 'adam', 'sgd', 'adan'. The previous version built an SGD
# optimizer and then immediately overwrote it with Adam, so 'adam' is the
# default to keep the effective behaviour while dropping the dead construction.
optimizer_name = 'adam'

if optimizer_name == 'sgd':
    optimizer = torch.optim.SGD(model.parameters(),
                                lr,
                                momentum=momentum,
                                weight_decay=decay)
elif optimizer_name == 'adam':
    optimizer = torch.optim.Adam(model.parameters(), lr)
elif optimizer_name == 'adan':
    optimizer = Adan(
        model.parameters(),
        lr=lr,                       # original note: Adan can take 5-10x Adam's lr
        betas=(0.02, 0.08, 0.01),    # beta 1-2-3 per the paper; beta3 noted as most sensitive
        weight_decay=0.02,           # original note: 0.02 is optimal per the author
    )
else:
    raise ValueError(f'Unknown optimizer_name: {optimizer_name!r}')

# Training function

In [None]:
def train_one_epoch(train_loader, criterion, optimizer, epoch_index, tb_writer, device,
                    net=None, log_every=10):
    """Run one training epoch and return a recent mean per-batch loss.

    Args:
        train_loader: iterable of ``(images, density)`` batches; ``len()`` of it
            is used to compute the TensorBoard step.
        criterion: loss function comparing the prediction with the density map.
        optimizer: optimizer updating the parameters of *net*.
        epoch_index: zero-based epoch number, used for the TensorBoard step.
        tb_writer: object with ``add_scalar(tag, value, step)``, e.g. a SummaryWriter.
        device: device each batch is moved to.
        net: model to train; defaults to the notebook-global ``model``.
        log_every: number of batches averaged per progress report.

    Returns:
        Mean loss of the last complete window of ``log_every`` batches. If the
        epoch has fewer than ``log_every`` batches, the mean over all batches
        seen instead; 0.0 for an empty loader.
    """
    if net is None:
        net = model  # backward compatibility: callers that omit net train the global model
    running_loss = 0.
    window = 0
    last_loss = None

    # enumerate gives the batch index for intra-epoch reporting.
    for i, (images, density) in enumerate(train_loader):
        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        outputs = net(images.to(device))
        # squeeze() drops size-1 dims so prediction and target line up;
        # presumably relies on a batch size of 1 (see the DataLoader note).
        loss = criterion(outputs.squeeze(), density.squeeze().to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        window += 1
        if window == log_every:
            last_loss = running_loss / log_every  # mean loss per batch in this window
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0.
            window = 0

    if last_loss is None:
        # No full window completed: report the batches seen rather than 0.
        last_loss = running_loss / window if window else 0.
    return last_loss

# Training

In [None]:
# Run setup: re-running this cell starts a fresh TensorBoard run from epoch 0.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 100

best_vloss = 1_000_000.

# torch.save does not create parent directories.
os.makedirs('models', exist_ok=True)

for epoch in range(EPOCHS):
    # Training mode is enabled here because the loop switches to eval() each epoch.
    model.train(True)
    print('EPOCH {}:'.format(epoch_number + 1))

    avg_loss = train_one_epoch(train_loader=train_loader,
                               criterion=criterion,
                               optimizer=optimizer,
                               epoch_index=epoch_number,
                               tb_writer=writer,
                               device=device)

    # Evaluation mode: disables dropout and uses population statistics for
    # batch normalization.
    model.eval()

    running_vloss = 0.0
    n_vbatches = 0
    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vtarget in test_loader:
            voutputs = model(vinputs.to(device))
            # Same squeeze-based alignment as the training loss, so both losses
            # are computed on identically shaped tensors.
            vloss = criterion(voutputs.squeeze(), vtarget.squeeze().to(device))
            running_vloss += vloss.item()  # accumulate a Python float, not a device tensor
            n_vbatches += 1

    # An empty test loader yields NaN instead of crashing; NaN never beats best_vloss.
    avg_vloss = running_vloss / n_vbatches if n_vbatches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                       {'Training': avg_loss, 'Validation': avg_vloss},
                       epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'models/model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1