In [1]:
import os
import cv2
import skimage.io
from tqdm.notebook import tqdm
import zipfile
import numpy as np
import pandas as pd

In [2]:
# PANDA training metadata: one row per slide.
TRAIN_CSV = '/kaggle/input/prostate-cancer-grade-assessment/train.csv'
train_df = pd.read_csv(TRAIN_CSV)
# The labelling and tile-loading cells below depend on these columns; fail
# early with the actual column list if the file layout differs.
assert {'image_id', 'gleason_score'} <= set(train_df.columns), train_df.columns.tolist()
print(train_df.shape)
train_df.head()

(10616, 4)


Unnamed: 0,image_id,data_provider,isup_grade,gleason_score
0,0005f7aaab2800f6170c399693a96917,karolinska,0,0+0
1,000920ad0b612851f8e01bcc880d9b3d,karolinska,0,0+0
2,0018ae58b01bdadc8e347995b69f99aa,radboud,4,4+4
3,001c62abd11fa4b57bf7a6c603a11bb9,karolinska,4,4+4
4,001d865e65ef5d2579c190a0e0350d8f,karolinska,0,0+0


In [3]:
# Root of the pre-extracted tile dataset; both sub-folders hold PNGs named
# "<image_id>_<i>.png" (see the tile-loading cell).
TILES_ROOT = '../input/panda-16x128x128-tiles-data'
TRAIN = TILES_ROOT + '/train'  # tile images
MASKS = TILES_ROOT + '/masks'  # masks with the same file names as the tiles

In [4]:
def get_dominant_color(a):
    """Return the most frequent pixel value ("colour") in an image array.

    Args:
        a: image as an array-like. A 3-D ``(H, W, C)`` array is treated as
            ``H*W`` pixels with ``C`` channels; a 2-D ``(H, W)`` array is
            treated as single-channel pixels.

    Returns:
        1-D NumPy array of length ``C`` (length 1 for 2-D input) holding the
        most common pixel. Ties resolve to the smallest pixel in lexicographic
        order, because ``np.unique`` returns sorted rows and ``argmax`` picks
        the first maximum.

    Raises:
        ValueError: if ``a`` contains no pixels.
    """
    a = np.asarray(a)
    if a.size == 0:
        raise ValueError("get_dominant_color() received an empty image")
    if a.ndim == 2:
        # Single-channel image: without an explicit channel axis,
        # reshape(-1, W) would count whole rows as "pixels".
        a = a[..., np.newaxis]
    pixels = a.reshape(-1, a.shape[-1])
    colors, count = np.unique(pixels, axis=0, return_counts=True)
    return colors[count.argmax()]

In [5]:
def getLabel(row):
    """Map a ``DataFrame.iterrows()`` item to a coarse grade label.

    Only slides whose Gleason score has identical primary and secondary
    patterns ("0+0", "3+3", "4+4", "5+5") are kept.

    Args:
        row: an ``(index, Series)`` pair as yielded by ``train_df.iterrows()``.

    Returns:
        ``(True, grade)`` with ``grade`` one of ``"0"``, ``"3"``, ``"4"``,
        ``"5"``; otherwise ``(False, "-1")``.
    """
    # Look the column up by label. The previous positional ``Series.get(3)``
    # relied on pandas' integer-position fallback for string-labelled Series,
    # which is deprecated in pandas 2.x; once removed, ``.get(3)`` returns
    # None and every slide would be silently dropped.
    gleason = row[1].get("gleason_score")
    grade = {"0+0": "0", "3+3": "3", "4+4": "4", "5+5": "5"}.get(gleason)
    # NOTE(review): all other scores (mixed patterns, and any alternative
    # spelling used for benign slides) are skipped — confirm this is intended.
    if grade is None:
        return False, "-1"
    return True, grade
        




In [6]:
from PIL import Image

# Tiles read per slide (indices 0..N_TILES-1); reading stops early at the
# first missing mask file.
N_TILES = 6
# Mask values that mark a tile as carrying the slide's grade when they appear
# in the tile's dominant mask colour.
TUMOUR_MASK_VALUES = (3, 4, 5)


def _load_tile(path):
    """Open a PNG tile, read its pixels into memory and release the file handle."""
    # Image.open is lazy and would otherwise keep one file descriptor open per
    # tile until the image is first used.
    with Image.open(path) as tile:
        tile.load()
    return tile


def load_tile_dataset(df, n_tiles=N_TILES):
    """Collect tile images and string grade labels for slides accepted by ``getLabel``.

    A tile is labelled with its slide's grade when its dominant mask colour
    contains one of ``TUMOUR_MASK_VALUES``; otherwise it is labelled ``"0"``.

    Args:
        df: metadata frame with ``image_id`` and ``gleason_score`` columns.
        n_tiles: maximum number of tiles to read per slide.

    Returns:
        ``(images, labels)``: parallel lists of PIL images and str labels.
    """
    images, labels = [], []
    for row in tqdm(df.iterrows(), total=len(df)):
        keep, score = getLabel(row)
        if not keep:
            continue
        name = row[1]["image_id"]
        for i in range(n_tiles):
            mask_path = f"{MASKS}/{name}_{i}.png"
            if not os.path.isfile(mask_path):
                break
            with Image.open(mask_path) as mask:
                dominant = get_dominant_color(np.asarray(mask))
            # NOTE(review): this assumes mask values 3-5 denote tumour for
            # every provider; if some masks use a different value scheme,
            # their tumour tiles are all labelled "0" here — confirm.
            has_tumour = any(value in dominant for value in TUMOUR_MASK_VALUES)
            images.append(_load_tile(f"{TRAIN}/{name}_{i}.png"))
            # Always append str labels: previously benign tiles got the int 0,
            # mixing types that the training loop only tolerated by accident.
            labels.append(score if has_tumour else "0")
    return images, labels


images, labels = load_tile_dataset(train_df)
                
        
            
        
        



HBox(children=(FloatProgress(value=0.0, max=10616.0), HTML(value='')))




In [7]:
# Images and labels are built as parallel lists; they must stay aligned.
assert len(images) == len(labels), (len(images), len(labels))
print(len(images))
print(len(labels))

34728
34728


In [8]:
from torchvision import transforms

# Per-channel ImageNet statistics, matching the ImageNet-pretrained backbone.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Resize so the shorter side is 256 px, keep the central 224x224 crop, convert
# to a float tensor and normalise.
# NOTE(review): tiles are presumably 128x128 (per the dataset name), so this
# upsamples and the crop discards a border of each tile — confirm intended.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)


In [9]:
from torch.utils.data import Dataset, DataLoader

class data(Dataset):
    """Map-style dataset pairing in-memory images with their labels.

    Args:
        images: sequence of images accepted by the transform (PIL images in
            this notebook).
        labels: sequence of labels, one per image.
        transform: callable applied to each image in ``__getitem__``. When
            ``None`` the notebook-level ``preprocess`` pipeline is used, which
            keeps the previous default behaviour.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, transform=None):
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels differ in length: {len(images)} != {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # Size of this dataset (previously read the global `images` list).
        return len(self.images)

    def __getitem__(self, idx):
        # Resolve the default lazily so the global `preprocess` is only needed
        # when no transform was supplied.
        transform = self.transform if self.transform is not None else preprocess
        return {'img': transform(self.images[idx]), 'label': self.labels[idx]}



In [10]:
import torch
import torch.nn as nn

# Output classes: grades 0, 3, 4, 5, remapped to indices 0..3 in the training loop.
NUM_CLASSES = 4

# ResNet-18 with semi-weakly supervised ImageNet weights, fetched via torch.hub.
model = torch.hub.load('facebookresearch/semi-supervised-ImageNet1K-models', 'resnet18_swsl')
# Replace the classification head; take its input width from the existing
# layer rather than hard-coding 512.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model

Downloading: "https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/archive/master.zip" to /root/.cache/torch/hub/master.zip
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




Downloading: "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnet18-118f1556.pth" to /root/.cache/torch/hub/checkpoints/semi_weakly_supervised_resnet18-118f1556.pth


HBox(children=(FloatProgress(value=0.0, max=46811375.0), HTML(value='')))




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [11]:
from torch.utils.data.sampler import SubsetRandomSampler

dataset = data(images, labels)

# NOTE(review): the training loop in this notebook was written for one sample
# per batch; check it before raising batch_size.
batch_size = 1
validation_split = .2
shuffle_dataset = True
random_seed = 42

# NOTE(review): the split is over individual tiles, so tiles of the same
# slide can land in both train and validation, which likely inflates the
# validation accuracy. Splitting grouped by image_id would avoid this.
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    # A private RandomState seeded like the global one produces the same
    # permutation as np.random.seed + np.random.shuffle, without mutating the
    # global NumPy RNG.
    np.random.RandomState(random_seed).shuffle(indices)
train_indices, val_indices = indices[split:], indices[:split]

# SubsetRandomSampler draws its order from torch's global RNG; seed it so the
# visiting order is reproducible across runs.
torch.manual_seed(random_seed)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                sampler=valid_sampler)

In [12]:
def get_num_correct(preds, labels):
    """Count the rows of ``preds`` whose highest-scoring class equals ``labels``.

    Args:
        preds: tensor of class scores, one row per sample.
        labels: tensor of target class indices, one per sample.

    Returns:
        Number of matches as a Python int.
    """
    top_class = preds.argmax(dim=1)
    matches = top_class == labels
    return matches.sum().item()

Training Loop

In [13]:
from tqdm.auto import tqdm  # widget bars in notebooks; avoids "[A" residue from nested console bars
import copy
import torch.nn.functional as F

# Training configuration.
N_EPOCHS = 16
LEARNING_RATE = 0.0006
CHECKPOINT_PATH = "/kaggle/working/PandaModel.pth"
# String grade labels from the tile-loading cell -> indices of the 4-way head.
GRADE_TO_CLASS = {"0": 0, "3": 1, "4": 2, "5": 3}
# Fall back to CPU so the cell still runs (slowly) without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_class_indices(raw_labels, device):
    """Convert a collated batch of grade labels into a tensor of class indices.

    The default collate yields a list of str for string labels and a tensor
    for int labels; both are accepted.

    Raises:
        ValueError: on a label that is not a known grade.
    """
    values = raw_labels.tolist() if torch.is_tensor(raw_labels) else list(raw_labels)
    try:
        indices = [GRADE_TO_CLASS[str(value)] for value in values]
    except KeyError as exc:
        raise ValueError(f"unexpected grade label {exc.args[0]!r}") from None
    return torch.tensor(indices, device=device)


def _train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader``; return (summed batch loss, number correct)."""
    model.train()
    total_loss, total_correct = 0.0, 0
    for batch in tqdm(loader, total=len(loader), leave=False):
        inputs = batch["img"].to(device)
        targets = _to_class_indices(batch["label"], device)
        preds = model(inputs)
        loss = F.cross_entropy(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_correct += get_num_correct(preds, targets)
    return total_loss, total_correct


def _evaluate(model, loader, device):
    """Return accuracy on ``loader`` in percent (0.0 for an empty loader)."""
    model.eval()
    correct, seen = 0, 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch in loader:
            inputs = batch["img"].to(device)
            targets = _to_class_indices(batch["label"], device)
            correct += get_num_correct(model(inputs), targets)
            seen += targets.size(0)
    return correct / seen * 100 if seen else 0.0


model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
highest_acc = 0.00
best_model = None
for epoch in tqdm(range(N_EPOCHS)):
    total_loss, total_correct = _train_one_epoch(model, train_loader, optimizer, device)
    print("epoch", epoch, "total_correct:", total_correct, "loss:", total_loss)
    val_acc = _evaluate(model, validation_loader, device)
    print("Validation Accuracy:", val_acc, "%")
    print("\n")

    if val_acc > highest_acc:
        highest_acc = val_acc
        best_model = copy.deepcopy(model)
        print("This is the best model so far, saving...")
        # NOTE(review): this pickles the whole module; saving
        # model.state_dict() would be more portable if loaders allow it.
        torch.save(model, CHECKPOINT_PATH)
print("Final Accuracy: ", highest_acc)

  0%|          | 0/16 [00:00<?, ?it/s]
  0%|          | 0/27783 [00:00<?, ?it/s][A
  0%|          | 1/27783 [00:00<6:18:42,  1.22it/s][A
  0%|          | 6/27783 [00:00<4:28:07,  1.73it/s][A
  0%|          | 12/27783 [00:01<3:10:18,  2.43it/s][A
  0%|          | 18/27783 [00:01<2:15:52,  3.41it/s][A
  0%|          | 24/27783 [00:01<1:37:47,  4.73it/s][A
  0%|          | 30/27783 [00:01<1:11:25,  6.48it/s][A
  0%|          | 35/27783 [00:01<54:29,  8.49it/s]  [A
  0%|          | 39/27783 [00:01<42:20, 10.92it/s][A
  0%|          | 43/27783 [00:01<33:31, 13.79it/s][A
  0%|          | 47/27783 [00:01<27:13, 16.98it/s][A
  0%|          | 53/27783 [00:02<21:43, 21.27it/s][A
  0%|          | 58/27783 [00:02<18:09, 25.46it/s][A
  0%|          | 63/27783 [00:02<15:32, 29.72it/s][A
  0%|          | 69/27783 [00:02<13:38, 33.86it/s][A
  0%|          | 74/27783 [00:02<12:21, 37.34it/s][A
  0%|          | 80/27783 [00:02<11:20, 40.69it/s][A
  0%|          | 86/27783 [00:02<10:38, 

epoch 0 total_correct: 26445 loss: 7248.7012271713465


  6%|▋         | 1/16 [10:31<2:37:46, 631.11s/it]
  0%|          | 0/27783 [00:00<?, ?it/s][A

Validation Accuracy: 94.81641468682506 %


This is the best model so far, saving...



  0%|          | 6/27783 [00:00<08:51, 52.22it/s][A
  0%|          | 12/27783 [00:00<08:46, 52.70it/s][A
  0%|          | 18/27783 [00:00<08:44, 52.91it/s][A
  0%|          | 24/27783 [00:00<08:45, 52.87it/s][A
  0%|          | 30/27783 [00:00<08:43, 53.06it/s][A
  0%|          | 36/27783 [00:00<08:41, 53.19it/s][A
  0%|          | 42/27783 [00:00<08:38, 53.45it/s][A
  0%|          | 48/27783 [00:00<08:35, 53.83it/s][A
  0%|          | 54/27783 [00:01<08:36, 53.69it/s][A
  0%|          | 60/27783 [00:01<08:46, 52.61it/s][A
  0%|          | 66/27783 [00:01<08:45, 52.71it/s][A
  0%|          | 72/27783 [00:01<08:43, 52.89it/s][A
  0%|          | 78/27783 [00:01<08:42, 53.03it/s][A
  0%|          | 84/27783 [00:01<08:42, 52.98it/s][A
  0%|          | 90/27783 [00:01<08:44, 52.78it/s][A
  0%|          | 96/27783 [00:01<08:46, 52.59it/s][A
  0%|          | 102/27783 [00:01<08:45, 52.63it/s][A
  0%|          | 108/27783 [00:02<08:42, 52.94it/s][A
  0%|          | 114/27783