In [None]:
import gc
from time import time, sleep

import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision.models.segmentation as models

from semantic_segmentation.dataset import CustomDataset, ToTensorNormalize
from stats import discord_bot

In [2]:
from models.ss import calculate_miou

In [3]:
# Load the serialized dataset payload that is later handed to CustomDataset(data=...).
# NOTE(review): torch.load unpickles the file, so only point it at trusted files.
# Presumably this holds arbitrary Python objects rather than a plain state dict;
# recent PyTorch releases may need an explicit `weights_only=False` for that — confirm.
data = torch.load("./models/SSDataset.pt")

In [4]:
# --- Run configuration --------------------------------------------------------
# Passed as `miniters` to the training progress bar.
tqdm_miniters = 5

# Seed torch's global RNG before anything below draws from it.
torch.manual_seed(2050808)

# Optimisation hyper-parameters.
num_epochs = 20
batch_size = 32
lr_start = 1e-4
lr_gamma = 0.999967
# lr_gamma ** num_steps_end_decay ~= 0.1, i.e. one decade of learning-rate decay.
num_steps_end_decay = 69773

# DeepLabV3 with a ResNet-101 backbone, initialised from torchvision's default weights.
model = models.deeplabv3_resnet101(
    weights=models.DeepLabV3_ResNet101_Weights.DEFAULT,
)

# Prefer the GPU when one is available; the bare name echoes the choice as cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# Notification helper for this run; the label records the model and LR schedule settings.
webhook = discord_bot(
    extra=f"DeepLabV3-ResNet101(lr={lr_start}; {lr_gamma}; {num_steps_end_decay})",
)

In [6]:
# Wrap the loaded data; the transform is constructed with the target device.
dataset = CustomDataset(data=data, transform=ToTensorNormalize(device=device))

# 90 % of the samples go to training; whatever is left after truncation is the test set.
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, lengths=[train_size, test_size])

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

In [7]:
# Evaluate roughly five times per epoch. Clamp to at least 1 so that a loader with
# fewer than five batches does not yield 0, which would make the
# `batch_count % eval_frequency` check in the training loop raise ZeroDivisionError.
eval_frequency = max(1, len(train_loader) // 5)

In [9]:
# Replace the last classifier layer with a fresh 13-output convolution that reuses the
# old layer's input width, kernel size and stride.
old_head = model.classifier[-1]
model.classifier[-1] = nn.Conv2d(
    in_channels=old_head.in_channels,
    out_channels=13,
    kernel_size=old_head.kernel_size,
    stride=old_head.stride,
)

# Drop the reference to the replaced layer and run a collection pass.
del old_head
gc.collect()

32

In [10]:
# Move the model to the target device, then compile it with TorchScript.
model = torch.jit.script(model.to(device))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr_start)


def _lr_factor(step):
    """Return the multiplicative LR factor after ``step`` scheduler steps.

    The factor decays as ``lr_gamma ** step`` and is frozen once ``step`` reaches
    ``num_steps_end_decay``, so the learning rate stops shrinking from there on.
    """
    return lr_gamma ** min(step, num_steps_end_decay)


# Exponential decay that ends after `num_steps_end_decay` steps.
# `ExponentialLR(..., last_epoch=N)` does not do this: `last_epoch` is the index to
# resume from, so the decay never stopped and the step counter started offset by N.
# Starting from the default counter also makes the manual "initial_lr" entry unnecessary.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_lr_factor)



In [None]:
def _evaluate_mean_iou(model, loader, num_classes=13):
    """Return the mean of the per-batch mIoU scores over ``loader``.

    Runs without gradient tracking. The caller is responsible for switching the
    model between eval and train mode around this call.
    Assumes ``calculate_miou`` returns a value ``np.mean`` can consume — TODO confirm.
    """
    iou_scores = []
    with torch.no_grad():
        for test_images, test_masks in tqdm(loader, desc="Eval", leave=False, miniters=int(tqdm_miniters*2.5)):
            test_outputs = model(test_images)['out']
            iou_scores.append(calculate_miou(test_outputs, test_masks, num_classes=num_classes))
    return np.mean(iou_scores)


# Metric histories, one entry per periodic evaluation (miou, batch_loss, timestep)
# or per epoch (epoch_loss).
miou = []
batch_loss = []
timestep = []
epoch_loss = []

train_start_time = time()
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0   # sum of per-sample losses seen so far this epoch
    batch_count = 0
    samples_seen = 0     # exact sample count, so partial batches do not skew the average

    # The context manager closes the bar even if a step raises.
    with tqdm(total=len(train_loader), desc=f"Epoch {epoch}", miniters=tqdm_miniters) as pbar:
        for images, masks in train_loader:
            optimizer.zero_grad()

            outputs = model(images)['out']
            loss = criterion(outputs, masks)

            loss.backward()
            optimizer.step()
            scheduler.step()  # the schedule advances per batch, not per epoch

            # CrossEntropyLoss averages over the batch; re-weight by batch size so that
            # running_loss is a per-sample sum.
            running_loss += loss.item() * images.size(0)
            samples_seen += images.size(0)
            batch_count += 1
            pbar.update()

            if batch_count % eval_frequency == 0:
                model.eval()
                t = time()
                mean_iou = _evaluate_mean_iou(model, test_loader, num_classes=13)
                curr_time = time()
                eval_time = round(curr_time - t)
                # Divide by the samples actually processed rather than
                # batch_count * batch_size, which over-counts a partial batch.
                t_loss = running_loss / samples_seen

                batch_loss.append(t_loss)
                miou.append(mean_iou)
                timestep.append([epoch, batch_count, round(curr_time - train_start_time)])

                msg = f"Epoch {epoch+1}, Batch {batch_count}, Mean IoU: {mean_iou:.4f}, Training Loss: {t_loss:.4f}\nLR: {scheduler.get_last_lr()[0]:.7f}, Eval time: {eval_time // 60:2d} min {eval_time % 60:2d} sec"
                webhook.send_msg(msg)
                print(msg, flush=True)

                model.train()

    # Record this epoch's loss BEFORE writing the results file; previously the file was
    # saved first, so it always lagged one epoch and never contained the final epoch.
    e_loss = running_loss / train_size
    epoch_loss.append([e_loss, running_loss])

    # Refresh the mIoU curve, store it where the webhook expects it, and send it.
    plt.clf()
    plt.plot(miou)
    plt.title("Mean IoU")
    plt.savefig(f"{webhook.path}/current.png")
    plt.show()
    webhook.send_img(epoch)

    # Per-epoch checkpoint with everything needed to resume training.
    # NOTE(review): the file name says "mobilenet" although the model is
    # DeepLabV3-ResNet101; left unchanged so existing consumers of these files keep working.
    current_state = {
        "optimizer" : optimizer.state_dict(),
        "scheduler" : scheduler.state_dict(),
        "model"     : model.state_dict()
    }
    torch.save(current_state, f"semantic_segmentation/ss_mobilenet_{epoch:02d}.pt")
    np.savez("semantic_segmentation/SS_Results.npz",
             miou=np.array(miou),
             batch_loss=np.array(batch_loss),
             epoch_loss=np.array(epoch_loss),
             timestep=np.array(timestep, dtype=np.uint64))

    # Report cumulative wall-clock training time as hours / minutes / seconds.
    train_time = round(time() - train_start_time)
    hr = train_time // 3600
    train_time %= 3600
    msg = f"Epoch {epoch+1}/{num_epochs}, Training Loss: {e_loss:.4f}, Training Time: {hr:3d} hr {train_time // 60:2d} min {train_time % 60:2d} sec"
    webhook.send_msg(msg)
    print(msg, flush=True)
    sleep(10)

print('Training complete.')