In [1]:
import torch
import glob
from PIL import Image
import torch.optim as optim
from torch.utils.data import DataLoader,  TensorDataset, Dataset
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt
import math

#import os
import shutil
import time
from tqdm import tqdm

#from torchvision import models, datasets, transforms

import segmentation_models_pytorch as smp

In [2]:
class MyDataset(Dataset):
    """Map-style dataset pairing ``data[i]`` with ``targets[i]``.

    Args:
        data: Indexable container of inputs (tensor, array or list).
        targets: Indexable container of labels, same length as ``data``.
        transform: Optional callable applied to the input only (never to the
            label) each time an item is fetched.

    Raises:
        ValueError: If ``data`` and ``targets`` have different lengths.
    """

    def __init__(self, data, targets, transform=None):
        if len(data) != len(targets):
            raise ValueError(
                "data and targets must have the same length, got {} and {}".format(
                    len(data), len(targets)))
        self.data = data
        self.targets = targets
        # Only Python lists (e.g. lists of file names) are sorted, in place as before.
        # Tensor.sort() is not in-place (it would just build and discard a sorted
        # copy) and ndarray.sort() reorders the values along the last axis in place,
        # which would corrupt image data, so array-like inputs are left untouched.
        if isinstance(self.data, list):
            self.data.sort()
        if isinstance(self.targets, list):
            self.targets.sort()
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.targets[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [3]:
# Open the .npz archive holding the cropped train/val images and masks.
DATASET_PATH = 'train_val_cropped dataset_0804_wb.npz'
data = np.load(DATASET_PATH)

In [None]:
#data['arr_0'][0][0]

In [4]:
# Scale pixel values to [0, 1]. Casting to float32 first and dividing in place
# avoids the float64 copy of the whole dataset that `array / 255` would create
# (presumably the stored images are uint8 -- TODO confirm).
pixels = data['arr_0'].astype(np.float32)
pixels /= 255
x_images = torch.from_numpy(pixels)
print(x_images.shape)
# NHWC -> NCHW. permute(0, 3, 1, 2) keeps height and width in their original
# order; transpose(1, 3) gives the same shape for square images but swaps H and
# W, i.e. every image would be spatially transposed.
x_images = x_images.permute(0, 3, 1, 2)
print(x_images.shape)
y_labels = torch.from_numpy(data['arr_1'])
print(y_labels.shape)
# Same layout change for the masks so they stay aligned with the images.
y_labels = y_labels.permute(0, 3, 1, 2)
print(y_labels.shape)

torch.Size([3428, 256, 256, 3])
torch.Size([3428, 3, 256, 256])
torch.Size([3428, 256, 256, 3])
torch.Size([3428, 3, 256, 256])


In [5]:
len_dataset = len(y_labels)
# 80 % of the samples go to training, the remainder to validation.
train_size = int(len_dataset * 0.8)
val_size = len_dataset - train_size

dataset = MyDataset(data=x_images.float(), targets=y_labels)
# Seed the split so the same samples land in train/val on every run; otherwise
# accuracies from different runs are measured on different validation sets.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=split_generator)

train_loader = DataLoader(train_dataset, batch_size=6, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=6, shuffle=False)

# Drop the reference to the loaded archive; the tensors built from it are kept.
# Note: this makes the cell non-re-runnable without re-running the load cell.
del data

In [6]:
def validate(model, valloader, device):
    """Compute pixel accuracy of ``model`` over ``valloader``.

    The predicted class of each pixel is the argmax of the model output over
    dim 1. Labels with the same number of dimensions as the output are reduced
    to class indices the same way (argmax over dim 1 after ``.long()``, matching
    what ``train`` feeds to the loss); labels with fewer dimensions are taken to
    already be class indices.

    The model is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch of images to per-class scores (N, C, ...).
        valloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified pixels in [0, 1]; 0.0 if the loader
        yields no batches.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in valloader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)

                # Comparing raw float scores with the labels is (almost) never
                # equal; compare class indices instead.
                predicted = torch.argmax(outputs, dim=1)
                if labels.dim() == outputs.dim():
                    target = torch.argmax(labels.long(), dim=1)
                else:
                    target = labels.long()

                correct += (predicted == target).sum().item()
                # Count pixels, not samples, so the ratio is a true accuracy.
                total += target.numel()
    finally:
        model.train(was_training)

    return correct / total if total else 0.0

In [14]:
def train(model, num_epochs=50, device=None, weights_path='model_weights.pth'):
    """Train ``model`` with multiclass Dice loss and validate once per epoch.

    Batches come from the notebook-level ``train_loader``; accuracy is computed
    by ``validate`` on the notebook-level ``val_loader``. Whenever the validation
    accuracy improves, the model's ``state_dict`` is written to ``weights_path``.

    Args:
        model: Segmentation network. ``smp.losses.DiceLoss`` defaults to
            ``from_logits=True``, so the model is expected to output raw logits.
        num_epochs: Number of passes over ``train_loader``.
        device: Device to train on; defaults to CPU. Pass e.g.
            ``torch.device('cuda')`` to use a GPU.
        weights_path: Destination file for the best weights.

    Returns:
        Tuple ``(sum_acc, sum_loss)`` of arrays with shape ``(1, num_epochs)``
        holding the validation accuracy and the mean training loss per epoch.
    """
    if device is None:
        device = torch.device("cpu")

    sum_acc = np.zeros((1, num_epochs))
    sum_loss = sum_acc.copy()
    model.to(device)

    criterion = smp.losses.DiceLoss('multiclass')
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    best_accuracy = 0

    for epoch in tqdm(range(num_epochs)):
        epoch_start = time.time()
        # Re-enter training mode every epoch: validation below switches to eval.
        model.train()
        running_loss = 0.0
        num_batches = 0

        for img_batch, labels_batch in tqdm(train_loader):
            optimizer.zero_grad()

            output = model(img_batch.to(device))

            # Channel-wise masks -> class-index map. No squeeze(): it would drop
            # the batch dimension for a batch of size 1 and shift the argmax axis.
            target = torch.argmax(labels_batch.to(device).long(), dim=1)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Mean loss over the epoch instead of whatever the last batch produced.
        epoch_loss = running_loss / max(num_batches, 1)

        # Validate once per epoch (not after every batch, which multiplies the
        # epoch time by the cost of a full validation pass per training step).
        model.eval()
        accuracy = validate(model, val_loader, device)

        if best_accuracy < accuracy:
            best_accuracy = accuracy
            print('Best accuracy improved')
            torch.save(model.state_dict(), weights_path)

        sum_acc[0, epoch] = accuracy
        sum_loss[0, epoch] = epoch_loss
        epoch_end = time.time()
        print("Epoch: {} Loss: {:.3f} Accuracy: {:.3f} Time: {:.4f}s".format(epoch, epoch_loss, accuracy, epoch_end-epoch_start))

    return sum_acc, sum_loss

In [15]:
# U-Net with an ImageNet-pretrained ResNet-34 encoder, 3 input channels and
# 3 output classes.
# activation=None: the network returns raw logits. smp.losses.DiceLoss defaults
# to from_logits=True and applies softmax itself, so a 'softmax2d' head here
# would apply softmax twice and flatten the gradients. Take argmax (or softmax)
# of the output at inference time.
model = smp.Unet(encoder_name='resnet34',
                 encoder_depth=5,
                 encoder_weights='imagenet',
                 decoder_use_batchnorm=True,
                 decoder_channels=(256, 128, 64, 32, 16),
                 decoder_attention_type=None,
                 in_channels=3,
                 classes=3,
                 activation=None,
                 aux_params=None)

# Short smoke-test run: 2 epochs.
accuracy, loss = train(model, 2)

  0%|          | 0/2 [00:00<?, ?it/s]
  0%|          | 0/457 [00:01<?, ?it/s][A
  0%|          | 0/2 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [None]:
#print(model)

In [None]:
# TODO notes:
# two classes: done; DiceLoss: done
# look for models pretrained on microscopy data
# efficiency / MobileNet (lighter encoders)
# Kaggle: look at U-Net use cases there, register to use it for training
# could change the /255 scaling to /127.5 - 1
# augmentations (rotation, stretching, color)
# finish the literature review on histotripsy and the use of neural networks

In [None]:
# Scratch 5-D all-zeros tensor of shape (1, 2, 3, 4, 5).
x = torch.zeros((1, 2, 3, 4, 5))