<a href="https://colab.research.google.com/github/ChloeZhou1997/BreastCancerCNN/blob/main/Breast_Cancer_Classfication_VGG_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')
!pip install pylibjpeg==1.1.1
!pip install pydicom==2.1.1
!pip install torchmetrics
!pip install pytorch_lightning
# !pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from numpy import random

In [3]:
normalizer = np.load('/content/drive/MyDrive/Project/normalizer.npy')
mean = normalizer[0]
std = normalizer[1]

#Data Loading

In [4]:
import torch
import torchvision
from torchvision import transforms
import pytorch_lightning as pl
# from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tqdm.notebook import tqdm
import numpy as np

In [5]:
def load_file(path):
  return np.load(path).astype(np.float32)

In [6]:
train_transforms = transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean,std),
                                       ]
)

val_transforms = transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean,std),                                     
])

In [7]:
train_dataset = torchvision.datasets.DatasetFolder("/content/drive/MyDrive/Data/processed/training",
                                                   loader = load_file,
                                                   extensions = "npy",
                                                   transform = train_transforms)

val_dataset = torchvision.datasets.DatasetFolder("/content/drive/MyDrive/Data/processed/validation",
                                                   loader = load_file,
                                                   extensions = "npy",
                                                   transform = train_transforms)

In [8]:
batch_size = 64

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle = False)

# Model Creation

In [9]:
# torchvision.models.vgg19();

In [10]:
import torchmetrics

class CheastCancer_VGG(pl.LightningModule):

  def __init__(self,init_weights=True):
    super().__init__()

    self.model = torchvision.models.vgg19()
    self.model.features[0] = torch.nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    self.model.classifier[6] = torch.nn.Linear(in_features=4096, out_features=1, bias=True)
    
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr = 1e-3)

    #loss function
    self.criterion = torch.nn.BCEWithLogitsLoss()

    #eval matrix
    self.train_acc = torchmetrics.Accuracy()
    self.val_acc = torchmetrics.Accuracy()

    #list to store loss curve and accuracy curve
    self.traincc, self.valacc = [],[]
    self.trainloss, self.valloss = [],[0]

  def forward(self, data):
    pred = self.model(data)
    return pred

  def training_step(self, batch, batch_idx):
    img, label = batch
    pred = self(img)
    # print(pred)
    # label = label.float()
    loss = self.criterion(pred[:,0],label.float())
    acc = self.train_acc(pred[:,0],label)

    print(f"Training Step : The Loss is {loss} and the accuracy is {acc}")

    self.log("Train Loss", loss)
    self.log("Step Train ACC", self.train_acc(torch.sigmoid(pred[:,0]), label.int()))

    return loss
  
  def training_epoch_end(self, outs):
    self.log("Train ACC", self.train_acc.compute())

  def validation_step(self, batch, batch_idx):
    img, label = batch
    # label = label.float()
    pred = self(img)
    # print(img)
    # print(pred)
    # loss = self.loss_fn(pred,label)
    loss = self.criterion(pred[:,0],label.float())
    acc = self.val_acc(torch.sigmoid(pred[:,0]), label.int())

    print(f"Validation Step : The Loss is {loss} and the accuracy is {acc}")

    self.log("Train Loss", loss)
    self.log("Step Train ACC", acc)


  def validation_epoch_end(self, outs):
    self.log("Val ACC", self.val_acc.compute())

  def configure_optimizers(self):
      return [self.optimizer]

#Model Trainig

In [11]:
def evaluation(model, dataloader):
  predictions, labels = [], []
  with torch.no_grad():
    for batch in dataloader:
      batch = [term.cuda() for term in batch]
      pred = model(batch[0]).squeeze(1)
      pred = torch.sigmoid(pred)
      pred = [1 if p >= 0.5 else 0 for p in pred]
      predictions += pred
      labels += batch[1]
  
  result = 0
  for p, l in zip(predictions, labels):

    if p == l:
      result += 1
    else:
      continue
  accuracy = result / len(labels)
  return accuracy

In [12]:
import tqdm
best_accuracy = float('-inf')
# state_dict = torch.load('file_path')
# model.load_state_dict(state_dict)
step = 0
patinet = 0
stop_train = False
model = CheastCancer_VGG()
model = model.cuda()
for epoch in range(50):
  bar = tqdm.trange(len(train_loader))
  total_loss = 0
  for epoch_step, batch in zip(bar, train_loader):
    model.optimizer.zero_grad()
    model.train()
    batch = [term.cuda() for term in batch]

    pred = model(batch[0])
    loss = model.criterion(pred.squeeze(1), batch[1].float())

    loss.backward()
    total_loss += loss.item()
    model.optimizer.step()
    bar.set_postfix(avg_loss='{}'.format(total_loss / (epoch_step+1)))

    if step % 500 == 0 and step != 0:
      model.eval()
      val_accu = evaluation(model, val_loader)
      print('Val Accuracy: {}'.format(val_accu))

      if val_accu >= best_accuracy:
        torch.save(model.state_dict(), '/content/drive/MyDrive/Project/checkpoint/VGG/checkpoint_ACC_{}.ckpt'.format(val_accu))
        patient = 0
      else:
        patient += 1
      
      if patient == 5:
        stop_train = True
        break
    step += 1
  if stop_train:
    print('Stop Training !')
    break

100%|██████████| 299/299 [03:48<00:00,  1.31it/s, avg_loss=1.722931771772761]
 67%|██████▋   | 201/299 [02:27<01:11,  1.37it/s, avg_loss=0.6960051068575075]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6950193631609148]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6927209275223339]
 34%|███▍      | 103/299 [01:15<02:22,  1.38it/s, avg_loss=0.6921390707676227]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6926266269939002]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6927838156055846]
  2%|▏         | 5/299 [00:04<03:30,  1.40it/s, avg_loss=0.6932323575019836]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:37<00:00,  1.08it/s, avg_loss=0.6926663392363583]
 69%|██████▉   | 206/299 [02:30<01:07,  1.38it/s, avg_loss=0.6926312328536729]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:40<00:00,  1.07it/s, avg_loss=0.6926569793136622]
100%|██████████| 299/299 [03:37<00:00,  1.38it/s, avg_loss=0.6925738208668687]
 36%|███▌      | 108/299 [01:19<02:19,  1.37it/s, avg_loss=0.6924405289352487]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6926407568829515]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925849370334459]
  3%|▎         | 10/299 [00:07<03:30,  1.38it/s, avg_loss=0.6926053437319669]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6926560465707428]
 71%|███████   | 211/299 [02:34<01:03,  1.38it/s, avg_loss=0.6927320333021991]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:38<00:00,  1.07it/s, avg_loss=0.6926235419053298]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6926022030438069]
 38%|███▊      | 113/299 [01:22<02:14,  1.38it/s, avg_loss=0.6928002902290278]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:38<00:00,  1.07it/s, avg_loss=0.692559001637143]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925615447022045]
  5%|▌         | 15/299 [00:11<03:26,  1.38it/s, avg_loss=0.6924264691770077]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:37<00:00,  1.08it/s, avg_loss=0.6926230758328901]
 72%|███████▏  | 216/299 [02:37<00:59,  1.38it/s, avg_loss=0.6928089296762845]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6926096821309731]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925306302249232]
 39%|███▉      | 118/299 [01:26<02:10,  1.38it/s, avg_loss=0.6928384789899618]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6925414706951001]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925727625355673]
  7%|▋         | 20/299 [00:15<03:21,  1.39it/s, avg_loss=0.6923218511399769]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:38<00:00,  1.07it/s, avg_loss=0.6926174060158108]
 74%|███████▍  | 221/299 [02:41<00:56,  1.38it/s, avg_loss=0.6925038429530891]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:37<00:00,  1.08it/s, avg_loss=0.6926101509942657]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925950654374318]
 41%|████      | 123/299 [01:29<02:08,  1.37it/s, avg_loss=0.6925328282579299]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:37<00:00,  1.08it/s, avg_loss=0.6926327060696273]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.692557278883497]
  8%|▊         | 25/299 [00:19<03:19,  1.37it/s, avg_loss=0.6922303827909323]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:37<00:00,  1.08it/s, avg_loss=0.6925238761614796]
 76%|███████▌  | 226/299 [02:44<00:53,  1.37it/s, avg_loss=0.6925265586848827]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:38<00:00,  1.07it/s, avg_loss=0.6926034936139416]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6926396425352448]
 43%|████▎     | 128/299 [01:33<02:03,  1.39it/s, avg_loss=0.6929331994795984]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:39<00:00,  1.07it/s, avg_loss=0.6925391054472397]
100%|██████████| 299/299 [03:36<00:00,  1.38it/s, avg_loss=0.6925985322748139]
 10%|█         | 30/299 [00:22<03:14,  1.38it/s, avg_loss=0.6934630447818387]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:38<00:00,  1.08it/s, avg_loss=0.6925785023233165]
 77%|███████▋  | 231/299 [02:46<00:48,  1.39it/s, avg_loss=0.6925590130789526]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:34<00:00,  1.09it/s, avg_loss=0.6926088341103749]
100%|██████████| 299/299 [03:34<00:00,  1.39it/s, avg_loss=0.6926335241882299]
 44%|████▍     | 133/299 [01:36<01:59,  1.39it/s, avg_loss=0.6928294492301657]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:35<00:00,  1.09it/s, avg_loss=0.6925866043687265]
100%|██████████| 299/299 [03:34<00:00,  1.39it/s, avg_loss=0.6925848788242276]
 12%|█▏        | 35/299 [00:25<03:09,  1.39it/s, avg_loss=0.6918298800786337]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:35<00:00,  1.09it/s, avg_loss=0.6925415029892554]
 79%|███████▉  | 236/299 [02:50<00:45,  1.39it/s, avg_loss=0.6922987349928683]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:35<00:00,  1.09it/s, avg_loss=0.6925894584145434]
100%|██████████| 299/299 [03:34<00:00,  1.39it/s, avg_loss=0.6925570456479305]
 46%|████▌     | 138/299 [01:39<01:53,  1.42it/s, avg_loss=0.692280158293333] 

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:31<00:00,  1.10it/s, avg_loss=0.692557929949617]
100%|██████████| 299/299 [03:32<00:00,  1.41it/s, avg_loss=0.6925798395405645]
 13%|█▎        | 40/299 [00:29<03:03,  1.41it/s, avg_loss=0.6921663342452631]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:31<00:00,  1.10it/s, avg_loss=0.6925443713880303]
 81%|████████  | 241/299 [02:53<00:41,  1.40it/s, avg_loss=0.692482054233551] 

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:33<00:00,  1.09it/s, avg_loss=0.6925799994165682]
100%|██████████| 299/299 [03:34<00:00,  1.39it/s, avg_loss=0.6925679853927331]
 48%|████▊     | 143/299 [01:43<01:52,  1.39it/s, avg_loss=0.69240885724624]  

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:33<00:00,  1.09it/s, avg_loss=0.6926148198918755]
100%|██████████| 299/299 [03:35<00:00,  1.39it/s, avg_loss=0.6925696477443478]
 15%|█▌        | 45/299 [00:33<03:01,  1.40it/s, avg_loss=0.6927265654439512]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:34<00:00,  1.09it/s, avg_loss=0.6925910986386813]
 82%|████████▏ | 246/299 [02:56<00:38,  1.38it/s, avg_loss=0.6925854967673298]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:34<00:00,  1.09it/s, avg_loss=0.6925920764339408]
100%|██████████| 299/299 [03:34<00:00,  1.40it/s, avg_loss=0.6925737455138394]
 49%|████▉     | 148/299 [01:46<01:46,  1.41it/s, avg_loss=0.6927562968042873]

Val Accuracy: 0.5174499767333643


100%|██████████| 299/299 [04:34<00:00,  1.09it/s, avg_loss=0.6925514765009034]
100%|██████████| 299/299 [03:34<00:00,  1.39it/s, avg_loss=0.6924987333674096]


In [13]:
# model2 = CheastCancer_VGG()

In [14]:
# checkpoint_callback2 = ModelCheckpoint(
#     dirpath = "/content/drive/MyDrive/Project/checkpoint/VGG",
#     filename="sample-breastcancer-{epoch:02d}-{Val ACC:.2f}",
#     monitor = "Val ACC",
#     save_top_k = 10,
#     mode ="max")

In [15]:
# gpus = 1
# trainer2 = pl.Trainer(gpus = gpus, logger = TensorBoardLogger(save_dir = "/content/drive/MyDrive/Project/logs/VGG"), log_every_n_steps = 1,
#                      callbacks = checkpoint_callback2, max_epochs = 600)

In [16]:
# trainer2.fit(model2, train_loader, val_loader)