# FCN on Pascal VOC 2012 and SBD: Semi-Supervised Learning

In [None]:
%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torchvision
from torchvision import models
import torch.utils.data as tud
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm import tqdm
from PIL import Image
from collections import Counter
from sklearn.metrics import jaccard_score
import pickle
import my_datasets as mdset
import eval_train as ev
from utils import * 




## Datasets: Pascal VOC 2012 and SBD

In [None]:
# --- Paths and run naming ---------------------------------------------------
dataroot_voc = '/data/voc2012'
dataroot_sbd = '/data/sbd'
SAVE_DIR = '/data/model'
model_name = 'fcn_voc_sbd30_semisup_g05_CE'
save = os.path.join(SAVE_DIR, model_name)
save_all_ep = True  # one checkpoint file per epoch instead of a single overwritten file

# --- Optimisation schedule --------------------------------------------------
batch_size = 2
gamma = 0.5  # weight of the supervised term; (1 - gamma) weights the equivariance term
n_epochs_supervised = 25  # number of fully supervised pretraining epochs

# --- Losses -----------------------------------------------------------------
# Alternatives previously tried for the equivariance term:
#   nn.L1Loss(reduction='mean')
#   nn.KLDivLoss(reduction='batchmean', log_target=False)
criterion_unsupervised = nn.CrossEntropyLoss(ignore_index=21)
criterion_supervised = nn.CrossEntropyLoss(ignore_index=21)  # The border class is ignored.
Loss = 'CE'  # Loss = 'KL' or 'CE' or None for L1,MSE…

# --- Data handling ----------------------------------------------------------
rotate = False  # forwarded to the training datasets
split = True  # keep only a labelled fraction of the training pool for supervision
fully_supervised = True  # compute the equivariance loss on the supervised dataset as well

In [None]:
# Training splits receive the `rotate` flag; the VOC validation split does not.
train_dataset_VOC = mdset.VOCSegmentation(
    dataroot_voc,
    year='2012',
    image_set='train',
    download=True,
    rotate=rotate,
)
val_dataset_VOC = mdset.VOCSegmentation(
    dataroot_voc,
    year='2012',
    image_set='val',
    download=True,
)
train_dataset_SBD = mdset.SBDataset(
    dataroot_sbd,
    image_set='train_noval',
    mode='segmentation',
    rotate=rotate,
)

### Concatenate datasets

In [None]:
# Pool used by the equivariance (unsupervised) loss: VOC train followed by SBD
# train_noval. Only the images of this pool are consumed by that loss.
_unsup_parts = [train_dataset_VOC, train_dataset_SBD]
train_dataset_unsup = tud.ConcatDataset(datasets=_unsup_parts)

### Split dataset

In [None]:
# Labelled subset used by the supervised loss.
# NOTE(review): the exact semantics of `split_dataset` (random or not, which
# fraction is returned) live in `utils`; 0.3 presumably means 30% of the pool,
# matching the "sbd30" tag in model_name.
if split:
    train_dataset_sup = split_dataset(train_dataset_unsup,0.3)
else:
    # Without a split, every training sample is treated as labelled, so
    # `train_dataset_sup` is always defined for the dataloader cell.
    train_dataset_sup = train_dataset_unsup

In [None]:
# With fully_supervised, the equivariance loss only sees the labelled subset:
# both training loaders then draw from the same dataset, each with its own shuffle.
if fully_supervised:
    train_dataset_unsup = train_dataset_sup

_train_loader_kwargs = dict(batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_train_sup = tud.DataLoader(train_dataset_sup, **_train_loader_kwargs)
dataloader_train_equiv = tud.DataLoader(train_dataset_unsup, **_train_loader_kwargs)
dataloader_val = tud.DataLoader(val_dataset_VOC, batch_size=batch_size)

# Decide which device we want to run on
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("device :", device)



In [None]:
# Report dataset sizes ("Taille" = "size").
for _label, _dataset in (
    ("Taille dataset train supervised :", train_dataset_sup),
    ("Taille dataset train unsupervised :", train_dataset_unsup),
    ("Taille dataset val VOC :", val_dataset_VOC),
):
    print(_label, len(_dataset))


## FCN model (PyTorch / torchvision)

In [None]:
def load_model(file=None,fcn=False,pretrained=False,map_location=None):
    """Build a torchvision segmentation model, or load a saved one from disk.

    Args:
        file: name of a model file inside ``SAVE_DIR``, as written by
            ``torch.save(model, ...)``. When given, ``fcn`` and ``pretrained``
            are ignored.
        fcn: if truthy, build ``fcn_resnet101``; otherwise build
            ``deeplabv3_resnet101``.
        pretrained: forwarded to the torchvision constructor.
        map_location: forwarded to ``torch.load`` when ``file`` is given, e.g.
            ``'cpu'`` to load a checkpoint saved on GPU on a CPU-only machine.
            The default ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        The model (a ``torch.nn.Module``).
    """
    if file is not None:
        # NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True,
        # which refuses whole pickled modules like the ones this notebook saves;
        # pass weights_only=False there if loading fails (trusted files only).
        return torch.load(os.path.join(SAVE_DIR,file),map_location=map_location)
    if fcn:
        return torchvision.models.segmentation.fcn_resnet101(pretrained=pretrained)
    return torchvision.models.segmentation.deeplabv3_resnet101(pretrained=pretrained)

In [None]:
# Fresh FCN-ResNet101 (pretrained=False: no pretrained segmentation weights requested).
model = load_model(pretrained=False, fcn=True)

In [None]:
# Move the model's parameters and buffers to `device`. nn.Module.to works in place
# and returns the module, which the notebook echoes as this cell's output.
model.to(device)

## Training


In [None]:
# SGD settings; this optimizer instance is reused by both training phases.
learning_rate = 1e-3  # same float as 10e-4
moment = 0.9
wd = 2e-4
n_epochs = 26  # number of semi-supervised epochs
optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=moment,
    weight_decay=wd,
)

# Rotation angles for the equivariance loss are drawn from
# [0, angle_max) or [360 - angle_max, 360) degrees.
angle_max = 30


In [None]:
# Metric histories: values appended once per epoch (or per periodic evaluation).
# Each name gets its own, distinct empty list.
(iou_train, iou_test,
 combine_loss_train, combine_loss_test,
 loss_train_unsup, loss_train_sup,
 loss_test, loss_test_unsup,
 pix_accuracy_train, pix_accuracy_test,
 accuracy_test, accuracy_train) = ([] for _ in range(12))

# Per-batch accumulators, averaged and reset at the end of every epoch.
(all_combine_loss_train, all_loss_train_sup, all_loss_train_unsup,
 all_iou_train, all_pix_accuracy) = ([] for _ in range(5))

### Fully supervised pretraining

In [None]:
## Pretrain the model in a fully supervised way.
# Anomaly detection is a debugging aid that slows down every backward pass;
# keep it off for normal training (switch to True only to locate NaN/Inf grads).
torch.autograd.set_detect_anomaly(False)
for ep in range(n_epochs_supervised):
    print("EPOCH",ep)
    model.train()
    for x,mask in dataloader_train_sup:
        x = x.to(device)
        mask = mask.to(device)
        pred = model(x)["out"]
        loss = criterion_supervised(pred,mask)
        all_loss_train_sup.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        all_iou_train.append(inter_over_union(pred.argmax(dim=1).detach().cpu(),mask.detach().cpu()))
        optimizer.step()

    # Epoch averages of the per-batch values, then reset the accumulators.
    m_iou = np.array(all_iou_train).mean()
    m_loss = np.array(all_loss_train_sup).mean()
    loss_train_sup.append(m_loss)
    iou_train.append(m_iou)
    all_loss_train_sup = []
    all_iou_train = []
    print("EP:",ep," loss train:",m_loss," iou train:",m_iou)

    # Evaluate on the VOC validation loader.
    model.eval()
    state = ev.eval_model(model,dataloader_val,device=device,num_classes=21)
    iou = state.metrics['mean IoU']
    acc = state.metrics['accuracy']
    loss = state.metrics['CE Loss']
    loss_test.append(loss)
    iou_test.append(iou)
    accuracy_test.append(acc)
    print('EP:',ep,'iou:',iou,'Accuracy:',acc,'Loss CE',loss)

    # Save exactly one checkpoint per epoch, and only after `save` points at this
    # epoch's path: saving to the stale `save` path would clobber the previous
    # epoch's checkpoint with the current weights.
    suffix = '_pretrain_sup'+('_ep'+str(ep) if save_all_ep else '')
    save_model = model_name+suffix+'.pt'
    save = os.path.join(SAVE_DIR,save_model)
    torch.save(model,save)

# After pretraining: equivariance metrics on the val loader, then the per-angle
# scores reported by ev.eval_model_all_angle.
m_pix_acc, m_loss_equiv = eval_accuracy_equiv(model,dataloader_val,
                    criterion=criterion_unsupervised,nclass=21,device=device,plot=False)
print("Equivariance val: pixel accuracy",m_pix_acc,"loss",m_loss_equiv)
d_iou = ev.eval_model_all_angle(model,batch_size=batch_size,device=device,num_classes=21)
for k in d_iou.keys():
    print('Scores for datasets rotate by',k,'degrees:')
    print('   mIoU',d_iou[k]['mIoU'],'Accuracy',d_iou[k]['Accuracy'],'CE Loss',d_iou[k]['CE Loss'])

### Semi-supervised training

In [None]:
# Learning rate for the semi-supervised phase (same float as 2*10e-4).
# The optimizer was created earlier with the pretraining rate, so rebinding
# `learning_rate` alone has no effect: push the value into its param groups.
learning_rate = 2e-3
for param_group in optimizer.param_groups:
    param_group['lr'] = learning_rate

In [None]:
# Anomaly detection is a debugging aid that slows down every backward pass;
# keep it off for normal training (switch to True only to locate NaN/Inf grads).
torch.autograd.set_detect_anomaly(False)
for ep in range(n_epochs):
    print("EPOCH",ep)
    model.train()

    # zip stops at the shorter loader: an epoch lasts
    # min(len(dataloader_train_sup), len(dataloader_train_equiv)) batches.
    for batch_sup,batch_unsup in zip(dataloader_train_sup,dataloader_train_equiv):
        optimizer.zero_grad()
        # Rotate to the left or to the right: angle in [0, angle_max) or [360-angle_max, 360).
        if random.random() > 0.5:
            angle = np.random.randint(0,angle_max)
        else:
            angle = np.random.randint(360-angle_max,360)

        # Equivariance term: only the images of the unsupervised batch are used.
        x_unsup,_ = batch_unsup
        loss_equiv,acc = compute_transformations_batch(x_unsup,model,angle,reshape=False,
                                                       criterion=criterion_unsupervised,Loss=Loss,
                                                       device=device)
        # Supervised term on the labelled batch.
        x,mask = batch_sup
        x = x.to(device)
        mask = mask.to(device)
        pred = model(x)["out"]
        loss_equiv = loss_equiv.to(device) # otherwise bug in combining the loss
        loss_sup = criterion_supervised(pred,mask)
        loss = gamma*loss_sup + (1-gamma)*loss_equiv # combine loss
        loss.backward()
        optimizer.step()

        # Per-batch values, averaged at the end of the epoch.
        all_pix_accuracy.append(acc) # accuracy between the original mask and the transform mask put back in place
        all_loss_train_unsup.append(loss_equiv.item())
        all_loss_train_sup.append(loss_sup.item())
        all_combine_loss_train.append(loss.item())

    # Epoch averages, then reset the accumulators.
    m_loss_combine = np.array(all_combine_loss_train).mean()
    m_acc = np.array(all_pix_accuracy).mean()
    combine_loss_train.append(m_loss_combine)
    pix_accuracy_train.append(m_acc)
    loss_train_sup.append(np.array(all_loss_train_sup).mean())
    loss_train_unsup.append(np.array(all_loss_train_unsup).mean())

    all_pix_accuracy = []
    all_loss_train_unsup = []
    all_loss_train_sup = []
    all_combine_loss_train = []
    # Values of the last batch of the epoch.
    print("loss sup :",loss_sup.item(),"loss unsup",loss_equiv.item(),"loss",loss.item())
    print("EP:",ep," combine loss train:",m_loss_combine," pixel accuracy between masks ",m_acc)

    ## Evaluate the model on the VOC validation loader.
    model.eval()
    state = ev.eval_model(model,dataloader_val,device=device,num_classes=21)
    iou = state.metrics['mean IoU']
    acc = state.metrics['accuracy']
    loss = state.metrics['CE Loss']
    print('EP:',ep,'iou:',iou,'Accuracy:',acc,'Loss CE',loss)
    loss_test.append(loss)
    iou_test.append(iou)
    accuracy_test.append(acc)
    # Equivariance metrics on val every 3 epochs.
    if ep%3==0:
        m_pix_acc, m_loss_equiv = eval_accuracy_equiv(model,dataloader_val,
                    criterion=criterion_unsupervised,nclass=21,device=device,plot=False)
        loss_test_unsup.append(m_loss_equiv)
        pix_accuracy_test.append(m_pix_acc)
    # Per-angle scores every 5 epochs.
    if ep%5==0:
        d_iou = ev.eval_model_all_angle(model,batch_size=batch_size,device=device,num_classes=21)
        for k in d_iou.keys():
            print('Scores for datasets rotate by',k,'degrees:')
            print('   mIoU',d_iou[k]['mIoU'],'Accuracy',d_iou[k]['Accuracy'],'CE Loss',d_iou[k]['CE Loss'])

    ## Save model: one file per epoch, or a single file overwritten every epoch.
    save_model = model_name+('_ep'+str(ep) if save_all_ep else '')+'.pt'
    save = os.path.join(SAVE_DIR,save_model)
    torch.save(model,save)

# After training: equivariance metrics on the val loader, then per-angle scores.
m_pix_acc, m_loss_equiv = eval_accuracy_equiv(model,dataloader_val,
                    criterion=criterion_unsupervised,nclass=21,device=device,plot=False)
print("Equivariance val: pixel accuracy",m_pix_acc,"loss",m_loss_equiv)
d_iou = ev.eval_model_all_angle(model,batch_size=batch_size,device=device,num_classes=21)
for k in d_iou.keys():
    print('Scores for datasets rotate by',k,'degrees:')
    print('   mIoU',d_iou[k]['mIoU'],'Accuracy',d_iou[k]['Accuracy'],'CE Loss',d_iou[k]['CE Loss'])

## Plot

In [None]:
# One figure per recorded curve. Each point is one epoch, except the
# equivariance *test* curves, which are only evaluated every 3rd semi-supervised
# epoch. loss_train_sup, loss_test and iou_test hold the pretraining epochs
# followed by the semi-supervised ones; iou_train only covers pretraining.
# The title uses the configured equivariance criterion instead of a hard-coded name.
_criterion_name = type(criterion_unsupervised).__name__
_curves = [
    # (title, values, x label, y label)
    ("Combine loss train", combine_loss_train, "epoch", "Loss"),
    ("Equivariance loss train", loss_train_unsup, "epoch", "Loss"),
    ("CE loss train (pretraining + semi sup.)", loss_train_sup, "epoch", "Loss"),
    ("Equivariance Accuracy train", pix_accuracy_train, "epoch", "Accuracy"),
    ("Mean iou train (pretraining)", iou_train, "epoch", "Mean IOU"),
    ("Cross entropy loss test", loss_test, "epoch", "Loss"),
    ("Equivariance loss test", loss_test_unsup, "evaluation (every 3 epochs)", "Loss"),
    ("Equivariance accuracy test", pix_accuracy_test, "evaluation (every 3 epochs)", "Accuracy"),
    ("Mean iou test", iou_test, "epoch", "Mean IOU"),
]
for _title, _values, _xlabel, _ylabel in _curves:
    plt.figure(figsize=(10,8))
    plt.subplot(2,1,1)
    plt.title(f"FCN {_criterion_name} semi sup. {_title}")
    plt.plot(_values)
    plt.xlabel(_xlabel)
    plt.ylabel(_ylabel)