In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd 

In [2]:
data_path = "/kaggle/input/hamamatsu-lg-binary/All"

In [3]:
import cv2
import numpy as np
lower = np.array([0, 3, 240])
upper = np.array([255, 255, 255])


In [4]:
all_data = os.listdir(data_path)
image_names = [image.split(".")[0] for image in  all_data if image.split(".")[-1]=="tiff"]

In [5]:
len(image_names)

1200

In [6]:
class SegmentationDataset(Dataset):
    
    def __init__(self,image_path, images,feature_extractor):
        """_summary_

        Args:
            image_path (string): _Path of the Images_
            images (_type_): _Array of name of images_
            feature_extractor (_type_): _Image feature extractor for Segformer_
        """
        
        self.image_path = image_path
        self.images = images
        self.feature_extractor = feature_extractor
        
        
    def __len__(self):
        """_summary_

        Returns:
            _int_: _length of dataset_
        """
        return len(self.images)
    
    def __getitem__(self,idx):


        image_name = self.images[idx]
        image = Image.open(self.image_path +"/" +image_name +".tiff")
        seg = cv2.imread(self.image_path +"/" +image_name +"_mitosis.jpg")
        image = cv2.cvtColor(seg, cv2.COLOR_BGR2RGB)
    
        mask = cv2.inRange(image, lower, upper)
        segmentation_mask = Image.fromarray(mask)
        ## Encoding the data using feature extractor
        encodings = self.feature_extractor(image,segmentation_mask,return_tensors = "pt")
        
        ## Removing the dimension of the batch
        for k,v in encodings.items():
            encodings[k] = v.squeeze_()  
            
        return encodings

In [7]:
from transformers import SegformerFeatureExtractor


feature_extractor = SegformerFeatureExtractor()

In [8]:
train_dataset = SegmentationDataset(data_path,image_names,feature_extractor)

In [9]:
train_dataset[0]['labels'].squeeze().unique()

tensor([  0, 255])

In [10]:
for i in train_dataset:
    if (len(i['labels'].squeeze().unique()) != 1):
        print(i['labels'].squeeze().unique())
        break

tensor([  0, 255])


In [11]:
for k,v in train_dataset[0].items():
    print(v.size())

torch.Size([3, 512, 512])
torch.Size([512, 512])


In [12]:
id2label = {0:0,1:255}
label2id = {0:0,255:1}

In [13]:
from transformers import SegformerForSemanticSegmentation
model_name = "nvidia/mit-b0"
model = SegformerForSemanticSegmentation.from_pretrained(model_name,
                                                         num_labels=2, 
                                                         id2label=id2label, 
                                                         label2id=label2id,
)

Downloading:   0%|          | 0.00/68.4k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.7M [00:00<?, ?B/s]

Some weights of the model checkpoint at nvidia/mit-b0 were not used when initializing SegformerForSemanticSegmentation: ['classifier.weight', 'classifier.bias']
- This IS expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SegformerForSemanticSegmentation from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b0 and are newly initialized: ['decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.linear_c.1.proj.bias', 'deco

In [14]:
from torch.utils.data import DataLoader
batch_size = 16
len(image_names)

1200

In [15]:
import random
random.shuffle(image_names)

In [16]:
train_dataset = SegmentationDataset(data_path,image_names[:1100],feature_extractor) ## Taking 100 Images for training for now
test_dataset =  SegmentationDataset(data_path,image_names[1100:1200],feature_extractor)

In [17]:
test_dataloader = DataLoader(test_dataset,batch_size=16,shuffle=True)

In [18]:
train_dataloader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)

In [19]:
len(test_dataloader)

7

Finetuning

In [20]:
from datasets import load_metric

metric = load_metric("mean_iou")
#metric.compute()

Downloading builder script:   0%|          | 0.00/3.14k [00:00<?, ?B/s]

In [21]:
path = "/kaggle/working/Hamatsu_LG_Binary.pt"
path1 = "/kaggle/working/Hamatsu_LG_Binary_loss.pt"

In [22]:
import torch
from torch import nn
#from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm


# define optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.00006)
# move model to GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print("done")

done


In [23]:
maxx = 100
model.train()
for epoch in range(50):  # loop over the dataset multiple times
    metric = load_metric("mean_iou")
    
    print("Epoch:", epoch)
    for idx, batch in enumerate(tqdm(train_dataloader)):
        # get the inputs;
        pixel_values = batch["pixel_values"].cuda()
        labels = batch["labels"].cuda()

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(pixel_values=pixel_values, labels=labels)
        loss, logits = outputs.loss, outputs.logits
        
        loss.backward()
        optimizer.step()

        # evaluate
        
    with torch.no_grad():
        
        for idx, batch in enumerate(tqdm(test_dataloader)):
    # get the inputs;
            pixel_values1 = batch["pixel_values"].cuda()
            labels1 = batch["labels"].cuda()
            outputs1 = model(pixel_values=pixel_values1, labels=labels1)
            loss, logits = outputs1.loss, outputs1.logits
            upsampled_logits = nn.functional.interpolate(logits, size=labels1.shape[-2:], mode="bilinear", align_corners=False)
            predicted = upsampled_logits.argmax(dim=1)
            metric.add_batch(predictions=predicted.detach().cpu().numpy(), references=labels1.detach().cpu().numpy())
        
        # note that the metric expects predictions + labels as numpy arrays

        # let's print loss and metrics every 100 batches
        #predictions=predicted.detach().cpu().numpy()

        #metrics = metric._compute(predictions=predicted.detach().cpu().numpy(), references=labels1.detach().cpu().numpy(),num_labels=len(id2label), 
                               #ignore_index=255,
                               #reduce_labels=False, # we've already reduced the labels before)
                                # )
        metrics = metric.compute(num_labels=len(id2label), ignore_index=255, reduce_labels=False)
        ars = loss.item()
        print("Loss:", ars)
        print("Mean_iou:", metrics["mean_iou"])
        print("Mean accuracy:", metrics["mean_accuracy"])

        if maxx > ars:
            maxx = ars
            torch.save(model, path1)

Epoch: 0


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  acc = total_area_intersect / total_area_label


Loss: 0.20233193039894104
Mean_iou: 0.4989256654653912
Mean accuracy: 0.9978513309307824
Epoch: 1


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.149998739361763
Mean_iou: 0.49863250907346596
Mean accuracy: 0.9972650181469319
Epoch: 2


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.16119346022605896
Mean_iou: 0.49941132197491084
Mean accuracy: 0.9988226439498217
Epoch: 3


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.06032943353056908
Mean_iou: 0.49995980997346695
Mean accuracy: 0.9999196199469339
Epoch: 4


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.05985676497220993
Mean_iou: 0.4999999427764216
Mean accuracy: 0.9999998855528432
Epoch: 5


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0898103266954422
Mean_iou: 0.4999334871274225
Mean accuracy: 0.999866974254845
Epoch: 6


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.07541331648826599
Mean_iou: 0.4999962232438284
Mean accuracy: 0.9999924464876568
Epoch: 7


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.028996778652071953
Mean_iou: 0.4998949565846618
Mean accuracy: 0.9997899131693236
Epoch: 8


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  iou = total_area_intersect / total_area_union


Loss: 0.04666813090443611
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 9


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.02971479296684265
Mean_iou: 0.49999755846065674
Mean accuracy: 0.9999951169213135
Epoch: 10


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.01442214660346508
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 11


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.014261017553508282
Mean_iou: 0.49997941858631745
Mean accuracy: 0.9999588371726349
Epoch: 12


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.021593963727355003
Mean_iou: 0.4999896425323173
Mean accuracy: 0.9999792850646346
Epoch: 13


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.01314836647361517
Mean_iou: 0.4999771677922354
Mean accuracy: 0.9999543355844708
Epoch: 14


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.03231147304177284
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 15


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.014755403622984886
Mean_iou: 0.4999960134240411
Mean accuracy: 0.9999920268480822
Epoch: 16


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.022936346009373665
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 17


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.013611406087875366
Mean_iou: 0.4999999427764216
Mean accuracy: 0.9999998855528432
Epoch: 18


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.009294810704886913
Mean_iou: 0.4999994086896903
Mean accuracy: 0.9999988173793806
Epoch: 19


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.00987278576940298
Mean_iou: 0.4999999427764216
Mean accuracy: 0.9999998855528432
Epoch: 20


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.015082846395671368
Mean_iou: 0.4999997520311604
Mean accuracy: 0.9999995040623209
Epoch: 21


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.006107619032263756
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 22


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.006582650355994701
Mean_iou: 0.49999834051622766
Mean accuracy: 0.9999966810324553
Epoch: 23


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.005433360114693642
Mean_iou: 0.4999999427764216
Mean accuracy: 0.9999998855528432
Epoch: 24


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.008167870342731476
Mean_iou: 0.49999925609348134
Mean accuracy: 0.9999985121869627
Epoch: 25


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.00555654289200902
Mean_iou: 0.49999877923032837
Mean accuracy: 0.9999975584606567
Epoch: 26


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.009506937116384506
Mean_iou: 0.4999980353238097
Mean accuracy: 0.9999960706476194
Epoch: 27


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.009369614534080029
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 28


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.002839779481291771
Mean_iou: 0.49999933239158584
Mean accuracy: 0.9999986647831717
Epoch: 29


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.024303626269102097
Mean_iou: 0.49999982832926493
Mean accuracy: 0.9999996566585299
Epoch: 30


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0041750529780983925
Mean_iou: 0.4999991797953769
Mean accuracy: 0.9999983595907538
Epoch: 31


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0031025779899209738
Mean_iou: 0.49999675733055976
Mean accuracy: 0.9999935146611195
Epoch: 32


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.004118818324059248
Mean_iou: 0.4999999237018955
Mean accuracy: 0.999999847403791
Epoch: 33


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.005314809735864401
Mean_iou: 0.49999935146611196
Mean accuracy: 0.9999987029322239
Epoch: 34


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0015619602054357529
Mean_iou: 0.4999999809254739
Mean accuracy: 0.9999999618509477
Epoch: 35


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0015436750836670399
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 36


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.002463812241330743
Mean_iou: 0.49999990462736943
Mean accuracy: 0.9999998092547389
Epoch: 37


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0035785993095487356
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 38


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0021618723403662443
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 39


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0008783847442828119
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 40


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0028031212277710438
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 41


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.002503001596778631
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 42


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.01026538573205471
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 43


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0017901876708492637
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 44


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.008201174437999725
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 45


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0006493509863503277
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 46


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0014147296315059066
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 47


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0009428558987565339
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 48


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0005849132430739701
Mean_iou: 1.0
Mean accuracy: 1.0
Epoch: 49


  0%|          | 0/69 [00:00<?, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

Loss: 0.0015961485914885998
Mean_iou: 1.0
Mean accuracy: 1.0


In [24]:
x=w

NameError: name 'w' is not defined