![WSDL2022](logo.png)

<center> <font color = "Salmon" size = 6 > Segmentation Hands-on </font> </center>
<center> <font color = "DarkSlateBlue" size = 5 > Sankarsan Seal </font> </center>
<center> 12th February 2022 </center>

In [None]:
import numpy as np
import torch
from matplotlib import pyplot as plt
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.transforms import functional as ttf
from tqdm.notebook import tqdm

from modules import DataGenerator as dg
from modules import PSPNetModel
from modules import SavingParameterState as sps


In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs (slowly) on machines without CUDA instead of
# failing on the first `.to(device=...)` call.
CUDA_DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Number of training epochs (each epoch is one full pass over train_dataloader).
NO_OF_ITERATIONS = 1000

# Learning rate for the Adam optimiser.
LEARNING_RATE = 1e-5

# Species name passed to the train / val / test data generators.
SPECIES_NAME = "Bengal"

# Mini-batch size for training; the validation and test loaders use batch size 1.
BATCH_SIZE = 4

# Checkpoint name prefix; the epoch number (or -1 for the final model) is appended.
MODEL_SAVING_STRING = "PSPNet_epoch"

# PSPNet Architecture

The original paper, *Pyramid Scene Parsing Network* (Zhao et al., CVPR 2017), is available at https://arxiv.org/abs/1612.01105

![PSPNet](PSPNetScreenshot2022-02-12083857.png)

In [None]:
# Build the PSPNet and move its parameters to the configured device.
# nn.Module.to() moves the module in place and returns the same object.
pspnet = PSPNetModel.PSPNet().to( device = CUDA_DEVICE )
# Last expression of the cell: lets the notebook display the model summary.
pspnet

In [None]:
# Training objective and optimiser for the PSPNet parameters.
criterion = CrossEntropyLoss()
optimizer = Adam( params = pspnet.parameters(), lr = LEARNING_RATE )

In [None]:
# One dataset + loader per split. Only the training split is shuffled and
# batched; validation and test samples are read one at a time, in order.
train_data_gen = dg.TrainDataGenerator(species_name=SPECIES_NAME)
train_dataloader = DataLoader(
    dataset=train_data_gen,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

val_data_gen = dg.ValDataGenerator(species_name=SPECIES_NAME)
val_dataloader = DataLoader(
    dataset=val_data_gen,
    batch_size=1,
    shuffle=False,
)

test_data_gen = dg.TestDataGenerator(species_name=SPECIES_NAME)
test_dataloader = DataLoader(
    dataset=test_data_gen,
    batch_size=1,
    shuffle=False,
)

In [None]:
# Train for NO_OF_ITERATIONS epochs. After each epoch, compute the mean
# validation loss; whenever it improves, remember the epoch and save a
# checkpoint named MODEL_SAVING_STRING + <epoch>.
lowest_val_loss = np.inf
best_epoch = -1

for epoch in tqdm( range( NO_OF_ITERATIONS ) ):

    # ---- training pass ----
    pspnet.train()

    total_loss = list()

    total_val_loss = list()

    for image_tensor, mask_tensor in train_dataloader:

        optimizer.zero_grad()

        segmentation_output = pspnet( image_tensor.to( device = CUDA_DEVICE ) )

        loss = criterion( segmentation_output , mask_tensor.to( device = CUDA_DEVICE )  )

        loss.backward()

        optimizer.step()

        total_loss.append( loss.item() )

    print("Total Loss after {0} epoch: {1}".format( epoch, np.mean( total_loss ) ))

    # ---- validation pass ----
    # eval() puts layers such as dropout / batch-norm (if PSPNet has any) into
    # inference mode. no_grad() stops autograd from recording a graph for every
    # validation batch; without it, each forward pass kept activations alive
    # for a backward pass that never happens, wasting GPU memory and time.
    pspnet.eval()

    with torch.no_grad():
        for image_tensor, mask_tensor in val_dataloader:
            segmentation_output = pspnet( image_tensor.to( device = CUDA_DEVICE ) )

            loss = criterion( segmentation_output , mask_tensor.to( device = CUDA_DEVICE )  )

            total_val_loss.append( loss.item() )

    mean_val_loss = np.mean( total_val_loss )

    # A NaN mean never compares smaller, so such epochs are never saved as "best".
    if  mean_val_loss < lowest_val_loss :

        print( "   ***Lowest mean validation loss is {0} at epoch {1}".format( mean_val_loss, epoch ) )

        best_epoch = epoch

        lowest_val_loss = mean_val_loss

        sps.save_model_parameters( model = pspnet,
                                 optimizer = optimizer,
                                 name_of_the_model = MODEL_SAVING_STRING + "{0}".format( epoch )
                                 )

# After training, also save the final-epoch state under the suffix -1.
# NOTE(review): best_epoch also starts at -1, so if no epoch ever improved
# (e.g. every validation loss was NaN), loading MODEL_SAVING_STRING + best_epoch
# later loads this final checkpoint.
sps.save_model_parameters( model = pspnet,
                          optimizer = optimizer,
                          name_of_the_model = MODEL_SAVING_STRING + "{0}".format( -1 ))

In [None]:
# Restore the model/optimizer state that the training cell saved after the
# last epoch (checkpoint suffix -1).
sps.load_model_parameter(
    model=pspnet,
    optimizer=optimizer,
    name_of_the_model=f"{MODEL_SAVING_STRING}{-1}",
)

In [None]:
# Visualise the currently loaded model on the test set: for each sample,
# show the input image (left) and the predicted class map (right).
pspnet.eval()
# Inference only: no_grad() avoids building autograd graphs for every sample.
with torch.no_grad():
    for i_index, ( image_tensor, mask_tensor ) in enumerate( test_dataloader ):

        print( i_index )

        segmentation_output = pspnet( image_tensor.to( device = CUDA_DEVICE ) )

        # The network output is what CrossEntropyLoss consumes during training,
        # i.e. per-class scores along dim 1 (assumed (N, C, H, W) — confirm
        # against PSPNetModel). argmax over that dim gives one label per pixel.
        # Passing the raw scores to to_pil_image instead scaled the logits by
        # 255 and cast them to bytes, which produced meaningless images.
        predicted_mask = segmentation_output.argmax( dim = 1 )[0].cpu().numpy()

        plt.subplot(121)
        plt.imshow( ttf.to_pil_image( image_tensor[0] ) )

        plt.subplot(122)
        plt.imshow( predicted_mask )
        plt.show()

In [None]:
# Display the epoch with the lowest mean validation loss found during training
# (still -1 if no epoch ever improved on the initial value of infinity).
best_epoch

In [None]:
# Restore the checkpoint from the epoch with the lowest mean validation loss.
sps.load_model_parameter(
    model=pspnet,
    optimizer=optimizer,
    name_of_the_model=f"{MODEL_SAVING_STRING}{best_epoch}",
)

In [None]:
# Visualise the best-validation checkpoint on the test set: input image (left)
# and predicted class map (right) for every test sample.
pspnet.eval()
# Inference only: no_grad() avoids building autograd graphs for every sample.
with torch.no_grad():
    for (image_tensor, mask_tensor) in test_dataloader:

        segmentation_output = pspnet( image_tensor.to( device = CUDA_DEVICE ) )

        # The network output is what CrossEntropyLoss consumes during training,
        # i.e. per-class scores along dim 1 (assumed (N, C, H, W) — confirm
        # against PSPNetModel). argmax over that dim gives one label per pixel.
        # Passing the raw scores to to_pil_image instead scaled the logits by
        # 255 and cast them to bytes, which produced meaningless images.
        predicted_mask = segmentation_output.argmax( dim = 1 )[0].cpu().numpy()

        plt.subplot(121)
        plt.imshow( ttf.to_pil_image( image_tensor[0] ) )

        plt.subplot(122)
        plt.imshow( predicted_mask )
        plt.show()