In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.io import read_image

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from IPython.display import clear_output

In [None]:
def image_transform(image):
    """Convert an image tensor to float and normalize it with mean 0.5 / std 0.5.

    The input is first scaled by 1/255, so a tensor holding values in 0..255
    ends up in the range [-1, 1].
    """
    unit_scaled = image.float() / 255.0
    centered = unit_scaled - 0.5
    return centered / 0.5

def mask_transform(mask):
    """Binarize a channel-first mask image into an int64 label map.

    A pixel is labelled 1 when every colour channel is >= 200 (i.e. it is
    near-white) and 0 otherwise.

    Args:
        mask: tensor of shape (C, H, W). Only the first three channels are
            inspected, so an extra alpha channel is ignored and a
            single-channel (grayscale) mask is handled as well.

    Returns:
        torch.int64 tensor of shape (H, W) containing 0/1 labels.
    """
    # Slicing (instead of indexing channels 0, 1, 2 explicitly) keeps this
    # working for single-channel masks, which would otherwise raise IndexError.
    is_bright = (mask[:3] >= 200).all(dim=0)
    return is_bright.to(torch.int64)

# Resize pipelines applied to the raw tensors before normalization / binarization.
# Train, validation and mask inputs currently share the same settings (size 512,
# antialiasing on) but stay separate objects so each can be tuned on its own.
train_image_transforms = nn.Sequential(transforms.Resize(512, antialias=True))

val_image_transforms = nn.Sequential(transforms.Resize(512, antialias=True))

mask_transforms = nn.Sequential(transforms.Resize(512, antialias=True))

In [None]:
class ForestRoadsDataset(Dataset):
    """Image / mask pairs listed in a CSV file.

    Column 0 of the CSV holds the image path and column 1 the mask path, both
    relative to ``root_dir``.

    Args:
        csv_file: path of the CSV index file.
        root_dir: directory the paths in the CSV are relative to.
        img_transform: optional callable applied to the raw image tensor
            before normalization.
        m_transform: optional callable applied to the raw mask tensor before
            binarization.
        max_len: optional cap on the reported dataset length. Values larger
            than the number of CSV rows are clamped, so ``len()`` never
            promises samples that do not exist.
    """

    def __init__(self, csv_file, root_dir, img_transform=None, m_transform=None, max_len=None):
        self.image_mask_df = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.img_transform = img_transform
        self.m_transform = m_transform
        self.max_len = len(self.image_mask_df)
        if max_len is not None:
            # Clamp into [0, number of rows]: a larger cap would make samplers
            # request indices that fail in __getitem__.
            self.max_len = max(0, min(max_len, self.max_len))

    def __len__(self):
        return self.max_len

    def __getitem__(self, index):
        """Return ``(image, mask)`` for row ``index`` of the CSV."""
        img_name = os.path.join(self.root_dir,
                                self.image_mask_df.iloc[index, 0])
        m_name = os.path.join(self.root_dir,
                              self.image_mask_df.iloc[index, 1])

        image = read_image(img_name)
        # `is not None` rather than truthiness: an empty nn.Sequential is falsy.
        if self.img_transform is not None:
            image = self.img_transform(image)
        # Always normalize, so the sample format does not depend on whether an
        # optional transform was supplied.
        image = image_transform(image)

        mask = read_image(m_name)
        if self.m_transform is not None:
            mask = self.m_transform(mask)
        # Always binarize, for the same reason.
        mask = mask_transform(mask)

        return image, mask

In [None]:
# Training split, capped at 8000 samples.
train_dataset = ForestRoadsDataset(
    csv_file='/kaggle/input/coursework/data/train/train_info.csv',
    root_dir='/kaggle/input/coursework/data/train',
    img_transform=train_image_transforms,
    m_transform=mask_transforms,
    max_len=8000,
)

# Validation split, used in full.
val_dataset = ForestRoadsDataset(
    csv_file='/kaggle/input/coursework/data/validation/validation_info.csv',
    root_dir='/kaggle/input/coursework/data/validation',
    img_transform=val_image_transforms,
    m_transform=mask_transforms,
)

In [None]:
# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
def plot_losses(train_losses, val_losses, train_accuracies, val_accuracies, recalls=None, precisions=None, qs=None, f1s=None):
    """Clear the cell output and redraw the per-epoch training curves.

    Args:
        train_losses, val_losses: per-epoch loss values.
        train_accuracies, val_accuracies: per-epoch accuracy values.
        recalls, precisions, qs, f1s: optional per-epoch validation metrics
            (``qs`` is plotted under the label 'iou'). Series left as ``None``
            are skipped.
    """
    clear_output()
    fig, axs = plt.subplots(2, 2, figsize=(13, 8))

    def _plot_series(ax, values, label):
        # Epochs are numbered from 1 on the x axis.
        ax.plot(range(1, len(values) + 1), values, label=label)

    loss_ax = axs[0, 0]
    _plot_series(loss_ax, train_losses, 'train')
    _plot_series(loss_ax, val_losses, 'validation')
    loss_ax.set_ylabel('loss')
    loss_ax.set_xlabel('epoch')
    loss_ax.legend()

    acc_ax = axs[0, 1]
    _plot_series(acc_ax, train_accuracies, 'train')
    _plot_series(acc_ax, val_accuracies, 'validation')
    acc_ax.set_ylabel('accuracy')
    acc_ax.set_xlabel('epoch')
    acc_ax.legend()

    metrics_ax = axs[1, 0]
    any_metric_plotted = False
    for values, label in ((recalls, 'recall'), (precisions, 'precision'), (qs, 'iou'), (f1s, 'f1')):
        if values is not None:
            _plot_series(metrics_ax, values, label)
            any_metric_plotted = True
    metrics_ax.set_ylabel('validation metrics')
    metrics_ax.set_xlabel('epoch')
    # legend() on an axis without labelled artists only produces a warning.
    if any_metric_plotted:
        metrics_ax.legend()

    # The fourth cell of the 2x2 grid is unused; hide its empty frame.
    axs[1, 1].axis('off')

    plt.show()

def training_epoch(model, optimizer, criterion, train_loader, tqdm_desc):
    """Train ``model`` for one pass over ``train_loader``.

    Reads the notebook-global ``device`` to place each batch.

    Args:
        model: network whose forward pass returns a dict with an 'out' entry
            holding per-class logits.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(logits, masks)``.
        train_loader: iterable of ``(images, masks)`` batches exposing ``.dataset``.
        tqdm_desc: label shown on the progress bar.

    Returns:
        ``(train_loss, train_accuracy)``: the per-sample mean loss and the mean
        per-image pixel accuracy over the dataset.
    """
    train_loss, train_accuracy = 0.0, 0.0
    model.train()
    for images, masks in tqdm(train_loader, desc=tqdm_desc):
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        logits = model(images)['out']
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so a short last batch is not over-weighted.
        train_loss += loss.item() * images.shape[0]

        # Pixels per image taken from the mask itself (H * W); squaring the
        # width would be wrong for non-square inputs.
        pixels_per_image = masks[0].numel()
        correct_pixels = (logits.argmax(dim=1) == masks).sum().item()
        train_accuracy += correct_pixels / pixels_per_image

    train_loss /= len(train_loader.dataset)
    train_accuracy /= len(train_loader.dataset)
    return train_loss, train_accuracy


@torch.no_grad()
def validation_epoch(model, criterion, val_loader, tqdm_desc):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Reads the notebook-global ``device`` to place each batch.

    Args:
        model: network whose forward pass returns a dict with an 'out' entry
            holding per-class logits.
        criterion: loss called as ``criterion(logits, masks)``.
        val_loader: iterable of ``(images, masks)`` batches exposing ``.dataset``.
        tqdm_desc: label shown on the progress bar.

    Returns:
        ``(val_loss, val_accuracy, recall, precision, q, f1)`` where ``q`` is
        ``tp / (tp + fn + fp)`` (intersection over union of the non-zero
        class) and all ratio metrics are computed from pixel counts summed
        over the whole loader.
    """
    val_loss, val_accuracy = 0.0, 0.0
    tp, fn, tn, fp = 0.0, 0.0, 0.0, 0.0

    model.eval()

    for images, masks in tqdm(val_loader, desc=tqdm_desc):
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)['out']
        loss = criterion(logits, masks)

        val_loss += loss.item() * images.shape[0]

        labels = logits.argmax(dim=1)
        labels_masks = (labels == masks)                     # correctly classified pixels
        labels_not_masks = torch.logical_not(labels_masks)   # misclassified pixels
        labels0 = (labels == 0)                              # predicted class 0
        labels1 = torch.logical_not(labels0)                 # predicted any other class

        # Pixels per image taken from the mask itself (H * W); squaring the
        # width would be wrong for non-square inputs.
        val_accuracy += labels_masks.sum().item() / masks[0].numel()
        tp += torch.logical_and(labels_masks, labels1).sum().item()
        fn += torch.logical_and(labels_not_masks, labels0).sum().item()
        tn += torch.logical_and(labels_masks, labels0).sum().item()
        fp += torch.logical_and(labels_not_masks, labels1).sum().item()

    val_loss /= len(val_loader.dataset)

    val_accuracy /= len(val_loader.dataset)
    # The tiny epsilon keeps the ratios defined when a denominator is zero.
    recall = tp / (tp + fn + 1e-16)
    precision = tp / (tp + fp + 1e-16)
    q = tp / (tp + fn + fp + 1e-16)
    f1 = 2*tp / (2*tp + fn + fp + 1e-16)

    return val_loss, val_accuracy, recall, precision, q, f1

    
def train(model, optimizer, scheduler, criterion, train_loader, val_loader, num_epochs,
          checkpoint_path='/kaggle/working/checkpoint'):
    """Run the full train / validate loop and checkpoint after every epoch.

    Args:
        model, optimizer, criterion: forwarded to ``training_epoch`` /
            ``validation_epoch``.
        scheduler: optional LR scheduler, stepped once per epoch; may be ``None``.
        train_loader, val_loader: data loaders for the two phases.
        num_epochs: number of epochs to run.
        checkpoint_path: file that receives the latest checkpoint; it is
            overwritten after every epoch. Defaults to the original
            hard-coded location.

    Returns:
        Eight per-epoch lists, in this order: train_losses, val_losses,
        train_accuracies, val_accuracies, recalls, precisions, qs, f1s.
    """
    train_losses, train_accuracies = [], []
    val_losses, val_accuracies = [], []
    recalls, precisions, qs, f1s = [], [], [], []

    for epoch in range(1, num_epochs + 1):
        train_loss, train_accuracy = training_epoch(
            model, optimizer, criterion, train_loader,
            tqdm_desc=f'Training {epoch}/{num_epochs}'
        )
        val_loss, val_accuracy, recall, precision, q, f1 = validation_epoch(
            model, criterion, val_loader,
            tqdm_desc=f'Validating {epoch}/{num_epochs}'
        )

        if scheduler is not None:
            scheduler.step()

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        recalls.append(recall)
        precisions.append(precision)
        qs.append(q)
        f1s.append(f1)
        plot_losses(train_losses, val_losses, train_accuracies, val_accuracies, recalls, precisions, qs, f1s)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }
        # Without the scheduler state a resumed run would restart the LR schedule.
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = scheduler.state_dict()
        torch.save(checkpoint, checkpoint_path)

    return train_losses, val_losses, train_accuracies, val_accuracies, recalls, precisions, qs, f1s

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

In [None]:
# Training configuration: two-class FCN-ResNet50 trained with SGD.
num_epochs = 30
model = models.segmentation.fcn_resnet50(num_classes=2)
model = model.to(device)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
    weight_decay=0.001,
)
# Class 1 is weighted 1.5x relative to class 0 in the loss.
criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.5]).to(device))
scheduler = None  # constant learning rate

(train_losses, val_losses,
 train_accuracies, val_accuracies,
 recalls, precisions, qs, f1s) = train(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=num_epochs,
)

In [None]:
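# Optional: rebuild the network and restore its weights from the checkpoint written
# during training. Uncomment the lines below to skip retraining.
# model = models.segmentation.fcn_resnet50(num_classes=2).to(device)

# checkpoint = torch.load('/kaggle/working/checkpoint')
# model.load_state_dict(checkpoint['model_state_dict'])

# model.eval()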

In [None]:
@torch.no_grad()
def plot_results(dataset, n_examples, split_name='train'):
    """Show random samples next to their ground-truth and predicted masks.

    Each row contains the de-normalized input image, the ground-truth mask and
    the mask predicted by the notebook-global ``model`` (run on the global
    ``device``).

    Args:
        dataset: indexable dataset yielding ``(image, mask)`` tensors.
        n_examples: number of rows to draw; nothing is drawn when it is <= 0.
        split_name: prefix used in the subplot titles. Defaults to 'train',
            the prefix that was previously hard-coded.
    """
    if n_examples <= 0:
        return

    # squeeze=False keeps `ax` two-dimensional even for a single row, so the
    # ax[i, j] indexing below works for n_examples == 1 too.
    fig, ax = plt.subplots(n_examples, 3, figsize=(15, 5 * n_examples), squeeze=False)

    model.eval()

    # Sample distinct indices whenever the dataset is large enough; fall back
    # to sampling with replacement only when more rows are requested than exist.
    inds = np.random.choice(len(dataset), n_examples, replace=n_examples > len(dataset))

    for i in range(n_examples):
        img, m = dataset[inds[i]]

        # Undo the (x - 0.5) / 0.5 normalization and move channels last for imshow.
        ax[i, 0].imshow((img * 0.5 + 0.5).permute(1, 2, 0))
        ax[i, 0].set_title(split_name + ' image №' + str(inds[i]))

        ax[i, 1].imshow(m, cmap='gray', vmin=0, vmax=1)
        ax[i, 1].set_title(split_name + ' ground truth №' + str(inds[i]))

        img = img.unsqueeze(0).to(device)
        predicted_m = model(img)['out'].argmax(dim=1).squeeze().cpu()

        ax[i, 2].imshow(predicted_m, cmap='gray', vmin=0, vmax=1)
        ax[i, 2].set_title(split_name + ' predicted mask №' + str(inds[i]))

    for i in range(n_examples):
        for j in range(3):
            ax[i, j].set_xticks([])
            ax[i, j].set_yticks([])

    # Render the figure (a bare plt.plot() only adds an empty line to the last axes).
    plt.show()

In [None]:
# Qualitative check on three randomly chosen training samples.
plot_results(train_dataset, n_examples=3)

In [None]:
# Qualitative check on three randomly chosen validation samples.
plot_results(val_dataset, n_examples=3)

In [None]:
# Held-out test split: same preprocessing as validation, fixed iteration order.
test_dataset = ForestRoadsDataset(
    csv_file='/kaggle/input/coursework/data/test/test_info.csv',
    root_dir='/kaggle/input/coursework/data/test',
    img_transform=val_image_transforms,
    m_transform=mask_transforms,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

In [None]:
# Accumulate pixel-level confusion counts over the whole test loader.
with torch.no_grad():
    tp, fn, tn, fp = 0, 0, 0, 0

    model.eval()

    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)['out']

        labels = logits.argmax(dim=1)
        labels_masks = (labels == masks)                     # correctly classified pixels
        labels_not_masks = torch.logical_not(labels_masks)   # misclassified pixels
        labels0 = (labels == 0)                              # predicted class 0
        labels1 = torch.logical_not(labels0)                 # predicted any other class

        tp += torch.logical_and(labels_masks, labels1).sum().item()
        fn += torch.logical_and(labels_not_masks, labels0).sum().item()
        tn += torch.logical_and(labels_masks, labels0).sum().item()
        fp += torch.logical_and(labels_not_masks, labels1).sum().item()

# Every evaluated pixel falls into exactly one of the four counters, so their
# sum is the total pixel count. This avoids assuming square images and avoids
# re-reading a sample from disk just to learn its size.
accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-16)
recall = tp / (tp + fn + 1e-16)
precision = tp / (tp + fp + 1e-16)
q = tp / (tp + fn + fp + 1e-16)
f1 = 2*tp / (2*tp + fn + fp + 1e-16)

print('accuracy: %0.5f' % accuracy)
print('recall: %0.5f' % recall)
print('precision: %0.5f' % precision)
print('iou: %0.5f' % q)
print('f1: %0.5f' % f1)

In [None]:
# Qualitative check on three randomly chosen test samples.
plot_results(test_dataset, n_examples=3)

In [None]:
# Persist the whole model object (not just a state_dict) to the working directory.
torch.save(model, '/kaggle/working/fcn')

In [None]:
# The file holds a fully pickled nn.Module (it was written with torch.save(model, ...)),
# not a state_dict. Recent PyTorch releases default torch.load to weights_only=True,
# which refuses such pickles, so opt out explicitly. Only do this for files this
# notebook produced itself: unpickling can execute arbitrary code.
model = torch.load('/kaggle/working/fcn', weights_only=False)
model.eval()

In [None]:
from torch.ao.quantization import QConfigMapping, get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# Post-training static quantization is carried out on the CPU.
device = torch.device('cpu')
model = model.to(device)
model.eval()

# Keep the runtime quantized engine in line with the qconfig below, so weights
# are packed and kernels executed for the same backend the observers target.
if 'qnnpack' in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = 'qnnpack'

qconfig = get_default_qconfig('qnnpack')
qconfig_mapping = QConfigMapping().set_global(qconfig)


def calibrate(model, data_loader):
    """Run ``model`` over ``data_loader`` so its observers record activation ranges."""
    model.eval()
    with torch.no_grad():
        for image, _ in data_loader:
            model(image)


# prepare_fx expects a *tuple* of positional example inputs; the trailing comma
# matters, since `(x)` is just `x`.
example_inputs = (next(iter(train_loader))[0],)
prepared_model = prepare_fx(model, qconfig_mapping, example_inputs)
# NOTE(review): calibration uses the test loader, the same data the quantized
# model is evaluated on afterwards; consider calibrating on validation data.
calibrate(prepared_model, test_loader)
quantized_model = convert_fx(prepared_model)

In [None]:
# Smoke test: run the quantized model on the first test image (with a batch
# dimension added) and display its raw output.
quantized_model(test_dataset[0][0].unsqueeze(0))

In [None]:
# Rebind the global `model` so the cells below, which read `model`, run the quantized network.
model = quantized_model

In [None]:
# Accumulate pixel-level confusion counts over the whole test loader for the
# network currently bound to `model`.
with torch.no_grad():
    tp, fn, tn, fp = 0, 0, 0, 0

    model.eval()

    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)

        logits = model(images)['out']

        labels = logits.argmax(dim=1)
        labels_masks = (labels == masks)                     # correctly classified pixels
        labels_not_masks = torch.logical_not(labels_masks)   # misclassified pixels
        labels0 = (labels == 0)                              # predicted class 0
        labels1 = torch.logical_not(labels0)                 # predicted any other class

        tp += torch.logical_and(labels_masks, labels1).sum().item()
        fn += torch.logical_and(labels_not_masks, labels0).sum().item()
        tn += torch.logical_and(labels_masks, labels0).sum().item()
        fp += torch.logical_and(labels_not_masks, labels1).sum().item()

# Every evaluated pixel falls into exactly one of the four counters, so their
# sum is the total pixel count. This avoids assuming square images and avoids
# re-reading a sample from disk just to learn its size.
accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-16)
recall = tp / (tp + fn + 1e-16)
precision = tp / (tp + fp + 1e-16)
q = tp / (tp + fn + fp + 1e-16)
f1 = 2*tp / (2*tp + fn + fp + 1e-16)

print('accuracy: %0.5f' % accuracy)
print('recall: %0.5f' % recall)
print('precision: %0.5f' % precision)
print('iou: %0.5f' % q)
print('f1: %0.5f' % f1)

In [None]:
# Qualitative check on three randomly chosen test samples, using whatever network `model` refers to now.
plot_results(test_dataset, n_examples=3)

In [None]:
from torch.utils.mobile_optimizer import optimize_for_mobile

# Tracing and the mobile export are done on the CPU with the model in eval mode.
device = torch.device('cpu')
model = model.to(device)
model.eval()

# One random 3x512x512 image serves as the example input for tracing.
dummy_input = torch.rand((1, 3, 512, 512), device=device)

# strict=False because the model's forward pass returns a dict.
torchscript_model = torch.jit.trace(model, example_inputs=dummy_input, strict=False)
optimized_torchscript_model = optimize_for_mobile(torchscript_model)
optimized_torchscript_model._save_for_lite_interpreter("optimized_torchscript_model_fcn_quant")