In [6]:
import numpy as np
from tqdm import tqdm
import os, sys

import torch
import torch.nn as nn
import segmentation_models_pytorch as smp

from dataset import get_dataloaders
from utils import AverageMeter, ConfusionMeter, Metric, Recorder, get_bool
from config import get_config

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Multiplier applied to the L1 term when it is added to the generator loss.
LAMBDA = 100

def train_dis(train_loader, model, model_dis, criterion, criterion_dis, criterion_l1, optimizer, optimizer_dis, lr_scheduler):
    """Run one epoch that alternates a discriminator step and a generator step.

    For every batch the discriminator is first updated on a "real" pair
    (image + ground-truth map) and a "fake" pair (image + generator output),
    then the generator is updated with ``criterion`` plus ``LAMBDA`` times
    ``criterion_l1``.

    NOTE(review): the generator loss computed here contains no adversarial
    term, so the discriminator currently has no influence on the generator.

    Args:
        train_loader: iterable yielding ``(images, targets, _)`` batches.
        model: generator (segmentation) network.
        model_dis: discriminator network.
        criterion: segmentation loss applied to ``(outputs, targets)``.
        criterion_dis: loss applied to discriminator outputs and 0/1 labels.
        criterion_l1: L1 loss added to the generator loss.
        optimizer: optimizer for ``model``.
        optimizer_dis: optimizer for ``model_dis``.
        lr_scheduler: scheduler stepped once per batch.

    Returns:
        Tuple ``(loss, accuracy, mIOU)`` of the epoch averages held by the
        meters; the loss is the ``criterion`` term only.
    """
    loss_meter = AverageMeter('train loss')
    acc_meter = AverageMeter('train accuracy')
    miou_meter = AverageMeter('train mIOU')
    conf_meter = ConfusionMeter(cfg.NUM_CLASS)
    model.train()
    model_dis.train()

    for images, targets, _ in tqdm(train_loader):
        # images are the raw inputs, targets are the ground-truth label maps
        images = images.to(DEVICE)
        targets = targets.long().to(DEVICE)

        # generator prediction ("fake" result)
        outputs = model(images)

        # ---- discriminator update -------------------------------------
        # TODO: how to concatenate this? should this targets have 19 channels?
        # The label map is repeated to match the generator's channel count and
        # cast to the image dtype so torch.cat never has to mix int and float.
        real_maps = targets.unsqueeze(1).repeat(1, outputs.size(1), 1, 1).to(images.dtype)
        dis_outputs_real = model_dis(torch.cat((images, real_maps), 1))
        # ones_like/zeros_like create the labels on the same device and dtype
        # as the discriminator output (torch.ones(shape) would live on the CPU).
        loss_dis_real = criterion_dis(dis_outputs_real, torch.ones_like(dis_outputs_real))

        # detach() stops the discriminator loss from back-propagating through
        # the generator, so the generator graph stays intact for its own step
        # and no retain_graph is needed.
        dis_outputs_fake = model_dis(torch.cat((images, outputs.detach()), 1))
        loss_dis_fake = criterion_dis(dis_outputs_fake, torch.zeros_like(dis_outputs_fake))

        # average of the real and fake losses
        loss_dis = (loss_dis_real + loss_dis_fake) / 2
        optimizer_dis.zero_grad()
        loss_dis.backward()
        optimizer_dis.step()

        # ---- generator update -----------------------------------------
        loss = criterion(outputs, targets)
        # TODO: should include l1 loss?
        # The target is cast and expanded explicitly; the values are the same
        # as with implicit broadcasting.
        l1_target = targets.unsqueeze(1).to(outputs.dtype).expand_as(outputs)
        loss_l1 = criterion_l1(outputs, l1_target)
        loss_gen = loss + loss_l1 * LAMBDA

        loss_meter.update(loss.item(), images.size(0))
        conf_meter.update(outputs.argmax(1), targets)
        metric = Metric(conf_meter.value())
        acc_meter.update(metric.accuracy())
        miou_meter.update(metric.miou())

        optimizer.zero_grad()
        # Nothing reuses this graph afterwards, so it is freed right away.
        loss_gen.backward()
        optimizer.step()
        lr_scheduler.step()

    return loss_meter.avg, acc_meter.avg, miou_meter.avg


def validate_dis(val_loader, model, model_dis, criterion):
    """Evaluate the generator after an adversarial training pass.

    The evaluation is the same as ``validate_gen``; the discriminator takes
    no part in it, so this delegates instead of duplicating the loop.

    Args:
        val_loader: iterable yielding ``(images, targets, _)`` batches.
        model: generator (segmentation) network to evaluate.
        model_dis: discriminator network; accepted for call-site symmetry
            with ``train_dis`` and not used.
        criterion: segmentation loss applied to ``(outputs, targets)``.

    Returns:
        Tuple ``(loss, accuracy, mIOU)`` as returned by ``validate_gen``.
    """
    return validate_gen(val_loader, model, criterion)

class Discriminator_Net(nn.Module):
    """Convolutional discriminator: four stride-2 4x4 conv stages and a 4x4 head.

    Every convolution is unpadded, so each stride-2 stage maps a spatial size
    ``s`` to ``(s - 4) // 2 + 1`` and the final layer maps ``s`` to ``s - 3``.
    The output goes through a sigmoid, so every value lies in ``[0, 1]``.

    Args:
        in_channels: channels of the concatenated input. The default 22
            matches a 3-channel image stacked with a 19-channel map.
        out_channels: channels of the output score map (default 19).
    """

    def __init__(self, in_channels=22, out_channels=19):
        super(Discriminator_Net, self).__init__()
        self.net = nn.Sequential(
            # C64 (no batchnorm on the first stage)
            nn.Conv2d(in_channels, 64, kernel_size=(4, 4), stride=(2, 2), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # C128
            nn.Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # C256
            nn.Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # C512
            nn.Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # The last layer, no batchnorm; sigmoid squashes scores to [0, 1]
            nn.Conv2d(512, out_channels, kernel_size=(4, 4), bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the sigmoid score map produced by the conv stack for ``input``."""
        return self.net(input)

def train_gen(train_loader, model, criterion, optimizer, lr_scheduler):
    """Train ``model`` for one pass over ``train_loader`` with ``criterion`` only.

    The optimizer and the learning-rate scheduler are both stepped once per
    batch. Returns the ``(loss, accuracy, mIOU)`` averages held by the meters.
    """
    model.train()

    epoch_loss = AverageMeter('train loss')
    epoch_acc = AverageMeter('train accuracy')
    epoch_miou = AverageMeter('train mIOU')
    confusion = ConfusionMeter(cfg.NUM_CLASS)

    for batch_images, batch_targets, _ in tqdm(train_loader):
        batch_images = batch_images.to(DEVICE)
        batch_targets = batch_targets.long().to(DEVICE)

        logits = model(batch_images)
        batch_loss = criterion(logits, batch_targets)

        # Book-keeping: running loss plus metrics from the confusion meter.
        epoch_loss.update(batch_loss.item(), batch_images.size(0))
        confusion.update(logits.argmax(1), batch_targets)
        batch_metric = Metric(confusion.value())
        epoch_acc.update(batch_metric.accuracy())
        epoch_miou.update(batch_metric.miou())

        # Parameter update followed by the per-batch scheduler step.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        lr_scheduler.step()

    return epoch_loss.avg, epoch_acc.avg, epoch_miou.avg


def validate_gen(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns the ``(loss, accuracy, mIOU)`` averages held by the meters.
    """
    epoch_loss = AverageMeter('validation loss')
    epoch_acc = AverageMeter('validation accuracy')
    epoch_miou = AverageMeter('validation mIOU')
    confusion = ConfusionMeter(cfg.NUM_CLASS)

    with torch.no_grad():
        model.eval()
        for batch_images, batch_targets, _ in tqdm(val_loader):
            batch_images = batch_images.to(DEVICE)
            batch_targets = batch_targets.long().to(DEVICE)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_targets)

            epoch_loss.update(batch_loss.item(), batch_images.size(0))
            confusion.update(logits.argmax(1), batch_targets)

            batch_metric = Metric(confusion.value())
            epoch_acc.update(batch_metric.accuracy())
            epoch_miou.update(batch_metric.miou())

    return epoch_loss.avg, epoch_acc.avg, epoch_miou.avg
    

class GAN:
    """Bundle a segmentation generator with a discriminator and train both.

    Reads its settings from the module-level ``cfg`` and writes logs and
    checkpoints below ``./checkpoint/{save_folder}``.

    Args:
        n_critics: number of ``train_dis`` passes run after each
            ``train_gen`` pass.
    """

    def __init__(self, n_critics=2):
        if cfg.MODEL != "unet":
            raise ValueError(cfg.MODEL)

        # generator network
        self.model = smp.Unet(
            encoder_name=cfg.MODEL_ENCODER,
            encoder_weights="imagenet",
            in_channels=3,
            classes=cfg.NUM_CLASS,
        ).to(DEVICE)

        # model_dis is the discriminator network; it has to live on the same
        # device as the generator because their tensors are concatenated.
        self.model_dis = Discriminator_Net().to(DEVICE)

        # number of epochs of discriminator training for each epoch of generator training
        self.n_critics = n_critics

    def train(self):
        """Run ``cfg.EPOCH`` rounds of generator and adversarial training.

        Each round runs one ``train_gen``/``validate_gen`` pass and then
        ``n_critics`` ``train_dis``/``validate_dis`` passes. Metric traces are
        written after every pass; checkpoints are written when ``cfg.SAVE`` is
        set and the validation mIOU beats the best value seen so far.

        Raises:
            ValueError: if ``cfg.OPTIMIZER`` or ``cfg.LR_SCHEDULER`` is not
                one of the supported names.
        """
        model = self.model
        model_dis = self.model_dis

        criterion = torch.nn.CrossEntropyLoss(ignore_index=-1).to(DEVICE)
        # TODO: discriminator net needs binary cross entropy?
        criterion_dis = torch.nn.BCELoss().to(DEVICE)
        # TODO: should add l1 loss to the generator loss?
        criterion_l1 = nn.L1Loss()

        if cfg.OPTIMIZER == "AdamW":
            optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.LR)
            optimizer_dis = torch.optim.AdamW(model_dis.parameters(), lr=cfg.LR)
        else:
            raise ValueError(cfg.OPTIMIZER)

        if cfg.LR_SCHEDULER == "CosineAnnealingWarmRestarts":
            scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=cfg.T_RESTART, eta_min=cfg.LR_MIN)
        else:
            raise ValueError(cfg.LR_SCHEDULER)

        train_loader, val_loader = get_dataloaders(cfg)

        recorder_gen = Recorder(["train_loss", "train_acc", "train_miou", "val_loss", "val_acc", "val_miou"])
        recorder_dis = Recorder(["train_loss", "train_acc", "train_miou", "val_loss", "val_acc", "val_miou"])

        # torch.save does not create missing directories.
        gen_dir = f"./checkpoint/{save_folder}/gen"
        dis_dir = f"./checkpoint/{save_folder}/dis"
        os.makedirs(gen_dir, exist_ok=True)
        os.makedirs(dis_dir, exist_ok=True)

        # Best validation mIOU so far, shared by both phases; starting at
        # -inf lets the first validated result count as an improvement.
        val_miou_ = float("-inf")

        for i in range(cfg.EPOCH):
            print("Epoch", i)

            # TRAIN GENERATOR
            train_loss, train_acc, train_miou = train_gen(train_loader, model, criterion, optimizer, scheduler)
            print("Gen train_loss:", train_loss)
            print("Gen train_acc:", train_acc)
            print("Gen train_miou:", train_miou)
            val_loss, val_acc, val_miou = validate_gen(val_loader, model, criterion)
            print("Gen val_loss:", val_loss)
            print("Gen val_acc:", val_acc)
            print("Gen val_miou:", val_miou)
            recorder_gen.update([train_loss, train_acc, train_miou, val_loss, val_acc, val_miou])

            torch.save(recorder_gen.record, f"{gen_dir}/trace.log")
            if cfg.SAVE and val_miou > val_miou_:
                torch.save({
                    "epoch": i,
                    "model": model,
                    "optimizer": optimizer,
                    "scheduler": scheduler,
                }, f"{gen_dir}/state.pth")
                val_miou_ = val_miou
                print("model saved.")

            # TRAIN DISCRIMINATOR
            for critic in range(self.n_critics):
                train_loss, train_acc, train_miou = train_dis(train_loader, model, model_dis,
                                                              criterion, criterion_dis,
                                                              criterion_l1, optimizer,
                                                              optimizer_dis, scheduler)
                print("Dis train_loss:", train_loss)
                print("Dis train_acc:", train_acc)
                print("Dis train_miou:", train_miou)
                val_loss, val_acc, val_miou = validate_dis(val_loader, model, model_dis, criterion)
                print("Dis val_loss:", val_loss)
                print("Dis val_acc:", val_acc)
                print("Dis val_miou:", val_miou)
                recorder_dis.update([train_loss, train_acc, train_miou, val_loss, val_acc, val_miou])

                torch.save(recorder_dis.record, f"{dis_dir}/trace.log")
                if cfg.SAVE and val_miou > val_miou_:
                    torch.save({
                        "epoch": i,
                        "model": model_dis,
                        # the optimizer that owns model_dis's parameters
                        "optimizer": optimizer_dis,
                        "scheduler": scheduler,
                    }, f"{dis_dir}/state.pth")
                    val_miou_ = val_miou
                    print("model saved.")
    
    
        
    
if __name__ == '__main__':
    # Load the experiment configuration and derive the checkpoint folder name.
    experiment_file = "baseline_0.yml"
    cfg = get_config(f"./experiments/{experiment_file}")
    save_folder = experiment_file.split('.')[0] + "_" + cfg.MODEL
    if cfg.SAVE:
        if os.path.exists(f"./checkpoint/{save_folder}"):
            # Ask before reusing an existing checkpoint folder.
            res = get_bool(f"./checkpoint/{save_folder} already exists. Overwrite? (y/n)")
            if not res: sys.exit(0)
        else:
            # makedirs also creates ./checkpoint itself when it is missing.
            os.makedirs(f"./checkpoint/{save_folder}")
            # The context manager closes the file once the config is written.
            with open(f"./checkpoint/{save_folder}/config.yml", 'w') as config_stream:
                cfg.dump(stream=config_stream)

    # Seed NumPy and PyTorch from the configuration.
    np.random.seed(cfg.SEED)
    torch.manual_seed(cfg.SEED)

    gan = GAN()
    gan.train()
    
        
    

    




BATCH_SIZE: 8
CROP_SIZE: (512, 512)
DATASET_ROOT: ./data/cityscapes
EPOCH: 50
LR: 0.0001
LR_MIN: 1e-07
LR_SCHEDULER: CosineAnnealingWarmRestarts
MODEL: unet
MODEL_ENCODER: resnet50
NAME_SUFFIX: 
NUM_CLASS: 19
OPTIMIZER: AdamW
SAVE: True
SEED: 42
SHAPE: (1024, 2048)
SUBSET: False
SUBSET_SIZE: 8
TRAIN_REPEAT: 2
T_RESTART: 500
./checkpoint/baseline_0_unet already exists. Overwrite? (y/n)y


RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 4.00 GiB total capacity; 2.85 GiB already allocated; 1.22 MiB free; 2.87 GiB reserved in total by PyTorch)