In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
from PIL import Image
import torch
from torchvision import transforms
import torchvision.models as m
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split

In [None]:
# Mount Google Drive at /content/drive; the dataset path used below lives under it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Dataset root; every sub-directory of it is treated as one class.
base_dir="/content/drive/MyDrive/Dataset_BUSI_with_GT"
# Keep directories only: a stray file in base_dir would otherwise be registered as
# a class and make Dataset.load_data() fail when it tries to list it.
# The listing order is preserved, so label indices are assigned as before.
class_names = [
    entry for entry in os.listdir(base_dir)
    if os.path.isdir(os.path.join(base_dir, entry))
]

print(class_names)

['benign', 'normal', 'malignant']


In [None]:
# Training hyper-parameters.
# Number of samples per mini-batch.
batch_size=16
# NOTE(review): this module-level value is not read by any code visible in this notebook.
hidden_size=256
# NOTE(review): the training loop further down iterates over range(60) — confirm which epoch count is intended.
epoch_no=50


In [None]:
class Dataset(torch.utils.data.Dataset):
    """Paired image / segmentation-mask / class-label samples read from disk.

    Expects ``base_dir/<class_name>/`` folders holding ``<stem>.png`` images with a
    matching ``<stem>_mask.png`` beside each one; images without that mask file are
    skipped. A sample's label is the index of its class in ``class_names``.

    Args:
        base_dir: Root directory containing one sub-directory per class.
        class_names: Sub-directory names; the list position defines the label id.
        transform: Optional callable applied to the RGB image array.
        max_samples_per_class: Upper bound on the samples taken from each class,
            or ``None`` (the default) for no limit.

    Call :meth:`load_data` after construction to populate the sample lists.
    """

    def __init__(self, base_dir, class_names, transform=None,max_samples_per_class=None):
        self.base_dir = base_dir
        self.class_names = class_names
        self.transform = transform
        self.image_paths = []
        self.mask_paths = []
        self.labels = []
        self.max_samples_per_class = max_samples_per_class

    def load_data(self):
        """Scan the class folders and (re)build the image, mask and label lists."""
        # Start from empty lists so a second call does not duplicate every sample.
        self.image_paths = []
        self.mask_paths = []
        self.labels = []

        limit = self.max_samples_per_class  # None means "take everything"
        for class_name in self.class_names:
            class_dir = os.path.join(self.base_dir, class_name)
            label = self.class_names.index(class_name)
            taken = 0
            # Sorted so the sample order does not depend on the file system.
            for image_file in sorted(os.listdir(class_dir)):
                if limit is not None and taken >= limit:
                    break
                # Only plain images qualify; mask files are picked up via their image.
                if not image_file.endswith('.png') or '_mask' in image_file:
                    continue
                mask_file = f"{os.path.splitext(image_file)[0]}_mask.png"
                mask_path = os.path.join(class_dir, mask_file)
                if not os.path.exists(mask_path):
                    continue
                self.image_paths.append(os.path.join(class_dir, image_file))
                self.mask_paths.append(mask_path)
                self.labels.append(label)
                taken += 1

    def get_data(self):
        """Return the ``(image_paths, mask_paths, labels)`` lists."""
        return self.image_paths,self.mask_paths,self.labels

    def __len__(self):
        """Number of loaded samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mask, label_tensor)``.

        Raises:
            FileNotFoundError: If the image or mask file cannot be decoded.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        if self.transform:
            image = self.transform(image)
            # NOTE: `mask_transform` is looked up as a module-level name, so it
            # must be defined before the first sample is fetched.
            mask = mask_transform(mask)

        label_tensor = torch.tensor(self.labels[idx])

        return image, mask, label_tensor

In [None]:
# Image pipeline: array -> float tensor -> 224x224 -> per-channel normalisation
# (the mean/std values are the ImageNet statistics used by torchvision's pretrained models).
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224,224]),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Mask pipeline: nearest-neighbour resizing keeps the mask values unchanged;
# the default bilinear mode would blend foreground and background at the borders.
mask_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize([224,224], interpolation=transforms.InterpolationMode.NEAREST),
])

dataset=Dataset(base_dir,class_names=class_names,transform=transform,max_samples_per_class=300)
dataset.load_data()
# Loader over the full dataset, using the batch size from the configuration cell.
data_loader=torch.utils.data.DataLoader(dataset,batch_size=batch_size,shuffle=False)


In [None]:
# 80 / 20 train / validation split.
dataset_len=len(dataset)
train_size = int(0.8 * dataset_len)
val_size = dataset_len - train_size
# Seed the split so the same samples land in each subset on every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)

# Shuffle training batches each epoch; keep validation order fixed.
train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
val_loader=torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False)

In [None]:
from torchvision.models.vision_transformer import VisionTransformer
from torchvision.models import vit_b_16, ViT_B_16_Weights

In [None]:
# Build a ViT-B/16 with torchvision's default pretrained weights.
vit_model = m.vit_b_16(weights=m.ViT_B_16_Weights.DEFAULT)
# Print each sub-module's qualified name and class so layer paths can be looked up.
for name, module in vit_model.named_modules():
    print(f"{name}: {module.__class__.__name__}")

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:02<00:00, 141MB/s]


: VisionTransformer
conv_proj: Conv2d
encoder: Encoder
encoder.dropout: Dropout
encoder.layers: Sequential
encoder.layers.encoder_layer_0: EncoderBlock
encoder.layers.encoder_layer_0.ln_1: LayerNorm
encoder.layers.encoder_layer_0.self_attention: MultiheadAttention
encoder.layers.encoder_layer_0.self_attention.out_proj: NonDynamicallyQuantizableLinear
encoder.layers.encoder_layer_0.dropout: Dropout
encoder.layers.encoder_layer_0.ln_2: LayerNorm
encoder.layers.encoder_layer_0.mlp: MLPBlock
encoder.layers.encoder_layer_0.mlp.0: Linear
encoder.layers.encoder_layer_0.mlp.1: GELU
encoder.layers.encoder_layer_0.mlp.2: Dropout
encoder.layers.encoder_layer_0.mlp.3: Linear
encoder.layers.encoder_layer_0.mlp.4: Dropout
encoder.layers.encoder_layer_1: EncoderBlock
encoder.layers.encoder_layer_1.ln_1: LayerNorm
encoder.layers.encoder_layer_1.self_attention: MultiheadAttention
encoder.layers.encoder_layer_1.self_attention.out_proj: NonDynamicallyQuantizableLinear
encoder.layers.encoder_layer_1.dropo

In [None]:
# Attribute path from the ViT root to the self-attention module of the first encoder block.
attr_path = ['encoder', 'layers', 'encoder_layer_0', 'self_attention']
module = vit_model
# Follow the path one attribute at a time.
for attr in attr_path:
    module = getattr(module, attr)
self_attn_module = module

In [None]:
# Sanity check: show the resolved module's class and whether it exposes `in_proj_weight`.
print(self_attn_module.__class__.__name__)
print(hasattr(self_attn_module, 'in_proj_weight'))

MultiheadAttention
True


In [None]:
class MaskGuidedAttention(nn.Module):
    """Linear projection of softmax-normalised, optionally mask-weighted input.

    Computes ``linear(softmax(image, dim=1) * mask)`` with a learnable weight of
    shape ``(out_features, in_features)`` and a bias of shape ``(out_features,)``.

    NOTE(review): this notebook installs the module as ``out_proj`` of an
    ``nn.MultiheadAttention``. That layer appears to read ``out_proj.weight`` and
    ``out_proj.bias`` directly rather than calling ``out_proj(...)`` — confirm
    that ``forward`` (and therefore the mask) is actually used in that setup.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the values
        # are filled in by reset_parameters() right below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in +-1/sqrt(fan_in))."""
        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / np.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, image, mask=None):
        """Project the softmax of ``image``, weighted by ``mask`` when one is given.

        Args:
            image: Input tensor whose last dimension is ``in_features``; the
                softmax is taken over dimension 1.
            mask: Optional tensor. It gets an extra dimension at index 1
                (e.g. [16, 1, 224, 224] -> [16, 1, 1, 224, 224]) and must then be
                broadcastable to the shape of ``image``. ``None`` skips masking.

        Returns:
            Tensor shaped like ``image`` with the last dimension replaced by
            ``out_features``.
        """
        attention_weights = torch.softmax(image, dim=1)
        if mask is not None:
            mask = mask.unsqueeze(1)
            mask = mask.broadcast_to(attention_weights.shape)
            attention_weights = attention_weights * mask
        return F.linear(attention_weights, self.weight, self.bias)

In [None]:
# Take the sizes from the existing output projection's weight (dimension 1 -> in, dimension 0 -> out).
in_features = self_attn_module.out_proj.weight.shape[1]
out_features = self_attn_module.out_proj.weight.shape[0]
# Swap in a freshly initialised MaskGuidedAttention of the same size; the previous
# out_proj module (and its weights) is no longer referenced by the attention layer.
# NOTE(review): confirm that nn.MultiheadAttention actually calls out_proj.forward;
# if it only reads out_proj.weight / out_proj.bias, the mask logic is never executed.
self_attn_module.out_proj = MaskGuidedAttention(in_features, out_features)

In [None]:
# Freeze the backbone except for the modified attention layer.
# The previous check compared module names against 'encoder.layer.0.SelfAttention',
# which matches no module of this model, and iterating the root module's
# parameters() froze every weight anyway — including the layer meant to be excluded.
# Matching on parameter names touches each parameter exactly once.
TRAINABLE_PREFIX = 'encoder.layers.encoder_layer_0.self_attention.'
for param_name, param in vit_model.named_parameters():
    param.requires_grad = param_name.startswith(TRAINABLE_PREFIX)


In [None]:
class TumorClassifier(nn.Module):
    """Classification head on top of a feature backbone.

    ``forward`` returns raw logits. ``nn.CrossEntropyLoss`` applies log-softmax
    itself, so feeding it softmax probabilities squashes the inputs into [0, 1]
    and bounds the loss away from zero. Use :meth:`predict_proba` when
    probabilities are needed.

    Args:
        num_classes: Number of output classes.
        hidden_size: Unused; kept so existing calls remain valid.
        backbone: Module mapping an image batch to 1000 features per sample.
            Defaults to the notebook-level ``vit_model``.
    """

    def __init__(self, num_classes, hidden_size=256, backbone=None):
        super().__init__()
        self.vit_model = vit_model if backbone is None else backbone
        # A new nn.Linear is trainable by default; assigning `requires_grad` on the
        # module object (as done previously) did not affect its parameters.
        self.fc = nn.Linear(in_features=1000, out_features=num_classes)
        self.Softmax = nn.Softmax(dim=1)

    def forward(self, image):
        """Return class logits of shape ``(batch, num_classes)``."""
        features = self.vit_model(image)
        return self.fc(features)

    def predict_proba(self, image):
        """Return class probabilities (softmax over the logits)."""
        return self.Softmax(self.forward(image))


In [None]:
import matplotlib.pyplot as plt

def visualize_attention_maps(image, output):
    """Show an input image and an attention map side by side.

    Args:
        image: Single image tensor in channel-first layout (it is permuted with
            ``(1, 2, 0)`` before being displayed).
        output: Attention-map tensor in the same layout.

    Both tensors are detached first, because ``.numpy()`` raises on tensors that
    require grad (e.g. anything produced by a model outside ``torch.no_grad()``).
    """
    fig, (image_ax, map_ax) = plt.subplots(1, 2, figsize=(12, 6))
    image_ax.imshow(image.detach().permute(1, 2, 0).cpu().numpy())
    image_ax.set_title('Input Image')
    map_ax.imshow(output.detach().permute(1, 2, 0).cpu().numpy())
    map_ax.set_title('Attention Map')
    plt.show()

In [None]:
# Classifier with one output per entry in class_names.
classifier=TumorClassifier(num_classes=len(class_names))
criterion = nn.CrossEntropyLoss()
# Adam over every parameter of the classifier, which includes the wrapped vit_model's.
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
# StepLR scales the learning rate by gamma=0.1 every 10 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

In [None]:
# Choose the compute device once; the classifier is moved only when CUDA is present.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    classifier.to(device)
print('GPU!!!' if use_cuda else 'CPU!!')

GPU!!!


In [None]:
# Per-batch training losses; the training loop appends one value per optimisation step.
train_losses = []
# NOTE(review): the two validation lists are not filled by any code visible in this notebook.
validation_accuracies = []
validation_losses = []

In [None]:
# NOTE(review): the configuration cell defines epoch_no=50 while this loop runs 60
# epochs; the count is left unchanged here — confirm which one is intended.
for epoch in range(60):
  # Make sure dropout layers are in training mode, e.g. after an evaluation pass
  # has switched the model to eval().
  classifier.train()
  for batch in train_loader:
    # The mask is moved to the device as well, but the classifier only consumes the image.
    image ,mask, label_tensor= batch
    image ,mask, label_tensor= image.to(device), mask.to(device), label_tensor.to(device)
    optimizer.zero_grad()
    op= classifier(image)
    loss=criterion(op,label_tensor)
    train_losses.append(loss.item())
    loss.backward()
    optimizer.step()

  # Advance the learning-rate schedule once per epoch; without this call the
  # StepLR scheduler created earlier never changes the learning rate.
  scheduler.step()

  # Reports the loss of the last batch of the epoch.
  print(f'Epoch{epoch+1},Loss{loss.item()}')

Epoch1,Loss1.0155643224716187
Epoch2,Loss0.8478308916091919
Epoch3,Loss0.681298017501831
Epoch4,Loss0.6381515264511108
Epoch5,Loss0.6167305707931519
Epoch6,Loss0.6024188995361328
Epoch7,Loss0.5928696393966675
Epoch8,Loss0.5857340097427368
Epoch9,Loss0.5798603296279907
Epoch10,Loss0.5748457908630371
Epoch11,Loss0.5706652402877808
Epoch12,Loss0.5672688484191895
Epoch13,Loss0.5645654201507568
Epoch14,Loss0.562443733215332
Epoch15,Loss0.5607826709747314
Epoch16,Loss0.5594663023948669
Epoch17,Loss0.5584028363227844
Epoch18,Loss0.5575292110443115
Epoch19,Loss0.5568017363548279
Epoch20,Loss0.5561859607696533
Epoch21,Loss0.5556532144546509
Epoch22,Loss0.5551806092262268
Epoch23,Loss0.5547520518302917
Epoch24,Loss0.5543571710586548
Epoch25,Loss0.5539911985397339
Epoch26,Loss0.5536569952964783
Epoch27,Loss0.5533654689788818
Epoch28,Loss0.5531238913536072
Epoch29,Loss0.5529271364212036
Epoch30,Loss0.5527607202529907
Epoch31,Loss0.5526075959205627
Epoch32,Loss0.5524567365646362
Epoch33,Loss0.55231

In [None]:
from sklearn.metrics import accuracy_score,confusion_matrix,roc_curve,auc
from sklearn.metrics import precision_recall_curve,classification_report
from sklearn.preprocessing import label_binarize

def validate(classifier, val_loader, device):
  """Run the classifier over ``val_loader`` without gradient tracking.

  The model is put into eval mode. Each batch must unpack into
  ``(images, mask, labels)``; the predicted class is the index of the largest
  score along dimension 1.

  Returns:
    Tuple ``(y_true, y_pred, total_correct)``: the true labels, the predicted
    labels (both as flat lists) and the number of matching predictions.
  """
  classifier.eval()

  targets, predictions = [], []
  n_correct = 0

  with torch.no_grad():
    for images, mask, labels in val_loader:
      images = images.to(device)
      mask = mask.to(device)
      labels = labels.to(device)

      scores = classifier(images)
      predicted = scores.max(dim=1).indices

      n_correct += (predicted == labels).sum().item()
      targets.extend(labels.cpu().numpy())
      predictions.extend(predicted.cpu().numpy())

  return targets, predictions, n_correct


y_true, y_pred, total_correct= validate(classifier, val_loader, device)

# NOTE: a call to visualize_attention_maps(image, output) used to sit here, but
# `output` is never defined in this notebook, so the cell stopped with a NameError.
# To plot a map, pass a single channel-first image tensor and the map to display.

accuracy = total_correct / len(y_true)
print(f"Validation Accuracy: {accuracy:.4f}")

# Label ids in the order of class_names; passing them explicitly keeps the matrix
# and the report aligned with the tick labels even if a class has no validation samples.
class_ids = list(range(len(class_names)))

# Stored under a new name: rebinding `confusion_matrix` would shadow the sklearn
# function and break this cell when it is run a second time.
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
print("Confusion Matrix:\n", cm)

plt.figure(figsize=(10, 8))
# fmt="d" prints the counts as plain integers.
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",xticklabels=class_names,yticklabels=class_names)
plt.xlabel("Predicted labels")
plt.ylabel("True labels")
plt.title("Confusion Matrix")
plt.show()

report = classification_report(y_true, y_pred, labels=class_ids, target_names=class_names, output_dict=True)
print("\n Classification Report:")
print(classification_report(y_true, y_pred, labels=class_ids, target_names=class_names))

print(f"Overall Accuracy: {accuracy_score(y_true, y_pred):.4f}")

# Per-class rows only (the aggregate rows of the report are left out).
class_report_dict = {class_name: report[class_name] for class_name in class_names}

# Line colour per label id.
class_colors = {0: 'blue', 1: 'red', 2: 'green'}
metric_keys = ['precision', 'recall', 'f1-score']

plt.figure(figsize=(8, 6))
for i, (class_name, metrics_dict) in enumerate(class_report_dict.items()):
    values = [metrics_dict[key] for key in metric_keys]
    plt.plot(range(3), values, label=class_name, marker='o', color=class_colors[i])
    for j, value in enumerate(values):
        plt.text(j, value, f"{value:.2f}", ha="center", va="bottom", fontsize=10)
plt.xlabel('Metric')
plt.ylabel('Value')
plt.title('Metrics for each class')
plt.xticks(range(3), ['Precision', 'Recall', 'F1-Score'])
# Full score range so that values below 0.7 are not cut off.
plt.ylim([0, 1.05])

# Custom legend; names come from class_names so they always match the label ids.
cls = {i: class_name.capitalize() for i, class_name in enumerate(class_names)}
legend_handles = [plt.Line2D([0], [0], marker='o', color='w', label=cls[i],
                             markerfacecolor=class_colors[i], markersize=10) for i in range(len(cls))]
plt.legend(handles=legend_handles, loc='upper right', bbox_to_anchor=(1.05, 1))
plt.show()