In [None]:
import torch
import glob
from PIL import Image
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader,  TensorDataset, Dataset
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
import math

from os import listdir
import shutil
import time
from tqdm import tqdm

#from torchvision import models, datasets, transforms
!pip install git+https://github.com/qubvel/segmentation_models.pytorch
import segmentation_models_pytorch as smp

import imageio

import pickle

from google.colab import drive
drive.mount('/gdrive')

import random
import os
def set_random_seed(seed):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and
    CUDA), and asks cuDNN to choose deterministic kernels.

    Note: PYTHONHASHSEED is only exported to the environment here; it affects
    child processes started afterwards, not the hashing of the interpreter
    that is already running.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.backends.cudnn.deterministic = True


set_random_seed(42)

In [None]:
# Dataset definition
class MyDataset(Dataset):
    """Pairs of pickled NumPy image / mask arrays stored in two directories.

    File names in each directory are sorted and matched by position, so the
    i-th image (in sorted order) is paired with the i-th mask.

    Args:
        data: directory containing the pickled image arrays.
        targets: directory containing the pickled mask arrays.
        transform: optional callable applied to the float image tensor
            returned by ``__getitem__``; the mask is returned untouched.
    """

    def __init__(self, data, targets, transform=None):
        self.dir_data = data
        self.dir_targets = targets
        self.transform = transform

        # List each directory once and cache the result: re-listing on every
        # __len__ call is slow on network mounts (e.g. Google Drive) and could
        # disagree with the indices __getitem__ uses if a directory changes.
        self.list_data = sorted(listdir(data))
        self.list_targets = sorted(listdir(targets))

    def __len__(self):
        n_data, n_targets = len(self.list_data), len(self.list_targets)
        if n_data != n_targets:
            raise Exception(f'error {n_data} {n_targets}')
        return n_data

    @staticmethod
    def _load_pickle(path):
        # Load one pickled object from `path`.
        # SECURITY: pickle.load can execute arbitrary code - only use this
        # dataset with files you created yourself.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def __getitem__(self, idx):
        """Return ``(image, label)``: image as float32 tensor, label with its stored dtype."""
        image = self._load_pickle(os.path.join(self.dir_data, self.list_data[idx]))
        label = self._load_pickle(os.path.join(self.dir_targets, self.list_targets[idx]))

        image = torch.from_numpy(image).float()
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.from_numpy(label)

In [None]:
# Balanced dataset: random 80/20 train/validation split
dir_img = '/gdrive/My Drive/Segmentation_project/new_datasets/train_val_images_cropped_balanced'
dir_mask = '/gdrive/My Drive/Segmentation_project/new_datasets/train_val_masks_cropped_balanced'

TRAIN_FRACTION = 0.8
BATCH_SIZE = 26

dataset_balanced = MyDataset(data=dir_img, targets=dir_mask)

# Derive the split sizes from the real dataset length instead of a hard-coded
# 5200: random_split raises unless the lengths sum exactly to len(dataset).
# Giving the remainder to validation guarantees the sizes always add up.
n_total = len(dataset_balanced)
n_train = int(n_total * TRAIN_FRACTION)
n_val = n_total - n_train

train_dataset, val_dataset = torch.utils.data.random_split(dataset_balanced, [n_train, n_val])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# IoU validation metrics
def validate_iou(model, val_loader, device, threshold=0.5, from_logits=True):
    """Mean per-image IoU of the binary masks predicted by `model`.

    Args:
        model: segmentation network producing one output channel per image;
            assumed to already live on `device`.
        val_loader: iterable of ``(images, labels)`` batches.
        device: device each batch is moved to.
        threshold: probability cut-off above which a pixel is positive.
        from_logits: when True (default) a sigmoid is applied to the model
            output before thresholding. Both models in this notebook have no
            output activation and are trained with DiceLoss(from_logits=True),
            so their raw output is logits; thresholding logits at 0.5 would
            really be a sigmoid cut-off of about 0.62.

    Returns:
        float: smp "micro-imagewise" IoU per batch, weighted by batch size so
        every validation image counts equally; NaN if the loader is empty.

    The model is switched to eval mode (BatchNorm running stats, no dropout)
    for the pass and restored to its previous mode afterwards.
    """
    was_training = model.training
    model.eval()
    score_sum, n_images = 0.0, 0
    try:
        # No autograd graph is needed for evaluation; this saves a lot of memory.
        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device)
                labels = labels.to(device)

                output = model(images)
                probs = torch.sigmoid(output) if from_logits else output
                pred_mask = (probs > threshold).long()

                # get_stats needs output and target of identical shape; masks
                # may be stored without the channel dim (train() unsqueezes
                # them for the loss as well).
                if labels.dim() == pred_mask.dim() - 1:
                    labels = labels.unsqueeze(1)

                tp, fp, fn, tn = smp.metrics.get_stats(pred_mask, labels.long(), mode="binary")
                batch_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")

                batch_size = images.shape[0]
                score_sum += batch_iou.item() * batch_size
                n_images += batch_size
    finally:
        model.train(was_training)

    return score_sum / n_images if n_images else float("nan")

In [None]:
# train function
def _train_one_epoch(model, loader, criterion, optimizer, device, augment=None):
    """Run one optimisation pass over `loader` and return the mean batch loss."""
    # Re-assert training mode every epoch in case validation switched it off.
    model.train()
    total_loss, n_batches = 0.0, 0
    for img_batch, labels_batch in loader:
        img_batch = img_batch.to(device)
        if augment is not None:
            img_batch = augment(img_batch)

        optimizer.zero_grad()
        output = model(img_batch)
        # Masks get a channel dim to line up with the (N, 1, ...) model output.
        loss = criterion(output, labels_batch.to(device).unsqueeze(1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1
    return total_loss / n_batches if n_batches else float('nan')


def train(model, num_epochs, train_data=None, val_data=None,
          weights_path='/gdrive/My Drive/TyurinaAV_segmentation/Learning/model_weights.pth',
          lr=0.001, weight_decay=1e-4, augment=None, device=None):
    """Train `model` with binary Dice loss and Adam, validating IoU each epoch.

    The model's state_dict is written to `weights_path` every time the
    validation IoU improves on the best value seen so far in this call.

    Args:
        model: single-channel segmentation network returning logits.
        num_epochs: number of passes over the training data.
        train_data: training DataLoader; defaults to the notebook-level
            ``train_loader``.
        val_data: validation DataLoader; defaults to the notebook-level
            ``val_loader``.
        weights_path: checkpoint file for the best weights. It is
            overwritten on every improvement and by any later call that
            uses the same path.
        lr, weight_decay: Adam hyper-parameters.
        augment: optional callable applied to every image batch (already on
            `device`) before the forward pass. Masks are NOT transformed, so
            only photometric augmentation is valid, e.g.
            ``torchvision.transforms.ColorJitter(0.5, 0.5, 0.5, 0.3)``.
            torchvision's tensor colour ops assume float images in [0, 1], so
            check the range of the pickled images first. No augmentation is
            applied by default.
        device: torch device; defaults to CUDA when available, else CPU.

    Returns:
        tuple(np.ndarray, np.ndarray): ``(sum_acc, sum_loss)``, each of shape
        ``(1, num_epochs)``, holding the validation IoU and the mean training
        loss of every epoch.
    """
    if train_data is None:
        train_data = train_loader
    if val_data is None:
        val_data = val_loader
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    sum_acc = np.zeros((1, num_epochs))
    sum_loss = np.zeros((1, num_epochs))
    model.to(device)

    criterion = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_accuracy = 0

    for epoch in tqdm(range(num_epochs)):
        epoch_start = time.time()

        epoch_loss = _train_one_epoch(model, train_data, criterion, optimizer, device, augment)
        accuracy = validate_iou(model, val_data, device)

        if best_accuracy < accuracy:
            best_accuracy = accuracy
            print('Best metrics')
            torch.save(model.state_dict(), weights_path)

        # Plain Python floats: never store (possibly CUDA, grad-tracking)
        # tensors in the numpy history arrays.
        sum_acc[0, epoch] = float(accuracy)
        sum_loss[0, epoch] = epoch_loss

        epoch_end = time.time()
        print("Epoch: {} Loss: {:.3f} IoU: {:.3f} Time: {:.4f}s".format(
            epoch + 1, epoch_loss, accuracy, epoch_end - epoch_start))

    return sum_acc, sum_loss

In [None]:
# U-Net with an ImageNet-pretrained EfficientNet-B4 encoder and one output
# channel. No output activation is configured, so the network returns logits,
# which is what DiceLoss(from_logits=True) inside train() expects.
model = smp.Unet(
    encoder_name='efficientnet-b4',
    encoder_weights='imagenet',
    encoder_depth=5,
    decoder_channels=(256, 128, 64, 32, 16),
    decoder_use_batchnorm=True,
    decoder_attention_type=None,
    in_channels=3,
    classes=1,
    aux_params=None,
)

# NOTE(review): train() saves its best weights to a fixed path, so the
# DeepLabV3+ run further down overwrites this model's checkpoint - copy it
# first if both are needed.
accuracy, loss = train(model, 100)

In [None]:
#np.save('/gdrive/My Drive/Segmentation_project/new_datasets/UNETacc_efficientnet-b4_batch=26_100ep_NEW_AUG.npy', accuracy)
#np.save('/gdrive/My Drive/Segmentation_project/new_datasets/UNETloss_efficientnet-b4_batch=26_100ep_NEW_AUG.npy', loss)

In [None]:
# DeepLabV3+ with the same ImageNet-pretrained EfficientNet-B4 encoder,
# output stride 16 and atrous rates (12, 24, 36). activation=None keeps the
# raw output as logits for DiceLoss(from_logits=True) inside train().
model = smp.DeepLabV3Plus(
    encoder_name='efficientnet-b4',
    encoder_weights='imagenet',
    encoder_depth=5,
    encoder_output_stride=16,
    decoder_channels=256,
    decoder_atrous_rates=(12, 24, 36),
    upsampling=4,
    in_channels=3,
    classes=1,
    activation=None,
    aux_params=None,
)

# NOTE(review): this rebinds `model`, `accuracy` and `loss` from the U-Net
# cell and, through train()'s fixed checkpoint path, replaces its saved weights.
accuracy, loss = train(model, 100)