In [19]:
# If you run across any errors during image processing, uncomment and run this part.

# !conda install -c conda-forge pillow -y
# !conda install -c conda-forge pydicom -y
# !conda install -c conda-forge gdcm -y
# !pip install pylibjpeg pylibjpeg-libjpeg
# !pip install pylibjpeg pylibjpeg-openjpeg




import numpy as np 
import torchvision
import pandas as pd 
import torch
from torch.utils import data
import pydicom 
from torch import nn
import pylibjpeg
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
from torchmetrics.classification import BinaryAccuracy

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

import os

# Data available at https://www.kaggle.com/competitions/rsna-breast-cancer-detection/data

In [20]:
train_csv = pd.read_csv('/kaggle/input/rsna-breast-cancer-detection/train.csv')

patient_ids = train_csv['patient_id']

train_csv.index = train_csv['patient_id']

train_img_dir = '/kaggle/input/rsna-breast-cancer-detection/train_images/'

train_csv = train_csv.drop(columns={'BIRADS','density'})

train_ids = patient_ids[:9000]
val_ids = patient_ids[9000:]

train_data = train_csv.loc[train_ids]
val_data = train_csv.loc[val_ids]


train_patient_ids = list(train_data.patient_id.apply(lambda x: str(x)+'/'))
train_img_ids = list(train_data.image_id.apply(lambda x: str(x)))
train_ids = []

for patient_id, img_id in zip(train_patient_ids, train_img_ids):
    train_ids.append(patient_id+img_id)
train_data['images'] = train_ids

val_patient_ids = list(val_data.patient_id.apply(lambda x: str(x)+'/'))
val_img_ids = list(val_data.image_id.apply(lambda x: str(x)))
val_ids = []

for patient_id, img_id in zip(val_patient_ids, val_img_ids):
    val_ids.append(patient_id+img_id)
    
val_data['images'] = val_ids
    

In [21]:
train_transform = transforms.Compose([
                        transforms.RandomVerticalFlip(),
                        transforms.Resize((128, 128)),
                        transforms.ToTensor(),
                    
                ])

class Dataset(data.Dataset):
    def __init__(self, indices, direc, transform=None):
#         self.train_csv = train_csv
        self.list_IDs = indices
        self.direc = direc
        self.transform = transform
        
    def __len__(self):
        return len(self.list_IDs)

    def __getitem__(self, index):
        
        ID = self.list_IDs[index]
        x = pydicom.dcmread(self.direc + ID + ".dcm")
        x = x.pixel_array
        x = x.astype(np.float64)
        mean = np.mean(x)
        std  = np.std(x)
        x -= mean
        x /= std
        x = Image.fromarray(x)
        if self.transform != None :
            x = self.transform(x)
        patient_id = int(ID.split('/')[0])
        y = train_csv[train_csv['patient_id']==patient_id]['cancer'].unique()[0]
        return x, y

In [22]:
train_set = Dataset(train_ids, train_img_dir, train_transform)
train_loader = data.DataLoader(train_set, batch_size = 2, shuffle = True)

val_set = Dataset(val_ids, train_img_dir, train_transform)
val_loader = data.DataLoader(val_set, batch_size = 2, shuffle = True)

In [23]:
negative_patients = train_data[train_data['cancer']==0].reset_index(drop=True)
positive_patients  = train_data[train_data['cancer']==1].reset_index(drop=True)

positive_weights = len(negative_patients)/len(train_data)
negative_weights = len(positive_patients)/len(train_data)

class CustomBCE(nn.Module):
    def __init__(self, p_weights, n_weights ):
        super(CustomBCE, self).__init__()
        self.p_weights = p_weights 
        self.n_weights = n_weights
        self.epsilon   =1e-7
    
    def forward(self,y_true,y_pred):
        loss = 0.0
        for y in y_pred:
            y = torch.max(y)
            loss += -torch.mean(self.p_weights*y_true*torch.log(y+self.epsilon) + self.n_weights*(1-y_true)*torch.log(1-y+self.epsilon)) 
        return torch.Tensor(loss)
        

In [24]:
vgg = torchvision.models.vgg19(pretrained=True).to(device)

class Net(nn.Module):
    def __init__(self, vgg):
        super(Net, self).__init__()
        self.model_input = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.model_features = list(vgg.features.children())[1:]
        self.model_output = list(vgg.classifier.children())[1:-1]
        self.fc1 = nn.Linear(8192, 4096)
        self.fc2 = nn.Linear(4096, 2)
        
    def forward(self, x):
        x = self.model_input(x)
        for feature_layer in self.model_features:
            x = feature_layer(x)
        x = torch.flatten(x,1)
        x = self.fc1(x)
        for classifier_layer in self.model_output:
            x = classifier_layer(x)
        x = F.softmax(self.fc2(x))
        return x
    
model = Net(vgg)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters())
criterion = CustomBCE(positive_weights, negative_weights)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


In [26]:
losses = []
val_losses = []

for epoch in range(50):
    # keep track of training and validation loss
    train_loss = []
    
    
    model.train()
    for x, y in train_loader:
        optimizer.zero_grad()
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
        print(train_loss)
        
    losses.append(np.mean(train_loss))
    print(f'Epoch: {epoch} Loss: {np.mean(train_loss)}')
    
    if epoch % 10 == 0:
        model.eval()
        val_loss = []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                loss = criterion(y_hat, y)
                val_loss.append(loss.item())

            val_losses.append(np.mean(val_loss))
            print(f'Validation Epoch: {epoch} Loss: {np.mean(val_loss)}')




[15.750470161437988]
[15.750470161437988, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750470161437988]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750470161437988, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750470161437988, 15.750469207763672, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672]
[15.750470161437988, 15.750469207763672, 15.750469207763672, 15.750469207763672, 15.7504692

KeyboardInterrupt: 