In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import os
import pandas as pd
from torchvision.io import read_image
from torchvision.io import ImageReadMode
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModel
from PIL import Image
import requests
from torch.utils.data import DataLoader
from medmnist import PathMNIST, INFO
from QuantumSelfAttentionLayer import QuantumSelfAttentionLayer



In [2]:
# Pin all computation to the CPU. To use a GPU when one is available, swap this for
# torch.device("cuda" if torch.cuda.is_available() else "cpu").
device = torch.device(type="cpu")

In [3]:
# Pretrained DINOv2-base backbone and its matching image preprocessor.
# use_fast=False pins the slow processor this checkpoint was saved with. Without it,
# transformers warns that the default becomes the fast processor in v4.52, which
# would slightly change preprocessed inputs (and hence embeddings) after an upgrade.
processor = AutoImageProcessor.from_pretrained('facebook/dinov2-base', use_fast=False)
model = AutoModel.from_pretrained('facebook/dinov2-base').to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [4]:
class DiseaseDataset(Dataset):
    """In-memory dataset pairing input rows with integer class labels.

    Both sequences are converted to tensors once, up front, and moved to the
    global `device`, so `__getitem__` only indexes the stored tensors.

    Args:
        newDatasetInput: array-like of inputs (anything `np.array` accepts).
        newDatasetOutput: array-like of labels; stored as int64 so they can be
            used directly as class-index targets for `nn.CrossEntropyLoss`.
    """

    def __init__(self, newDatasetInput, newDatasetOutput):
        inputs_np = np.array(newDatasetInput)
        labels_np = np.array(newDatasetOutput)
        self.newDatasetInput = torch.tensor(inputs_np).to(device)
        self.newDatasetOutput = torch.tensor(labels_np, dtype=torch.int64).to(device)

    def __len__(self):
        # Number of samples = size of the leading dimension of the input tensor.
        return self.newDatasetInput.size(0)

    def __getitem__(self, idx):
        """Return the `(embedding, label)` pair at position `idx`."""
        return self.newDatasetInput[idx], self.newDatasetOutput[idx]

In [5]:
# Download dataset info
info = INFO['pathmnist']
label_names = info['label']

['adipose',
 'background',
 'debris',
 'lymphocytes',
 'mucus',
 'smooth muscle',
 'normal colon mucosa',
 'cancer-associated stroma',
 'colorectal adenocarcinoma epithelium']

In [7]:
def load_disease_dataset(save_dir):
    """Load a saved split from `save_dir` and wrap it in a `DiseaseDataset`.

    Expects two files written with `torch.save`: `newDatasetInput.pt` (inputs)
    and `newDatasetOutput.pt` (labels).

    Args:
        save_dir: directory containing the two `.pt` files (e.g. "train", "test").

    Returns:
        DiseaseDataset built from the loaded inputs and labels.

    Raises:
        FileNotFoundError: if either file is missing.
    """
    input_path = os.path.join(save_dir, "newDatasetInput.pt")
    output_path = os.path.join(save_dir, "newDatasetOutput.pt")

    # map_location="cpu": the files may have been saved from GPU tensors. Loading onto
    # the CPU lets them open on CPU-only machines, and DiseaseDataset converts through
    # np.array (host memory) before moving the data to `device` anyway.
    # NOTE(review): torch.load unpickles its input — only load files you produced yourself.
    inputs = torch.load(input_path, map_location="cpu")
    outputs = torch.load(output_path, map_location="cpu")

    return DiseaseDataset(inputs, outputs)

In [8]:
# Training split: reads train/newDatasetInput.pt and train/newDatasetOutput.pt.
train_dataset = load_disease_dataset(save_dir="train")

  self.newDatasetInput = torch.tensor(np.array(newDatasetInput)).to(device)
  self.newDatasetOutput = torch.tensor(np.array(newDatasetOutput), dtype=torch.int64).to(device)


In [9]:
# Held-out split: reads test/newDatasetInput.pt and test/newDatasetOutput.pt.
test_dataset = load_disease_dataset(save_dir="test")

  self.newDatasetInput = torch.tensor(np.array(newDatasetInput)).to(device)
  self.newDatasetOutput = torch.tensor(np.array(newDatasetOutput), dtype=torch.int64).to(device)


In [10]:

class DiseaseClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_gelu_stack1 = nn.Sequential(
            nn.Linear(768, 512),
            nn.GELU(),
        )

        self.quantum = nn.Sequential(
            QuantumSelfAttentionLayer(512, 4, 5),
            nn.Linear(5, 512),
        )


        self.linear_gelu_stack2 = nn.Sequential(
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 9),
        )

    def forward(self, x):
        x = self.linear_gelu_stack1(x)
        q = self.quantum(x)
        x = x + q
        x = self.linear_gelu_stack2(x)
        return x


In [11]:
class DiseaseClassifierFull(nn.Module):
    """End-to-end model: raw images -> processor -> backbone -> classifier head.

    The backbone's `last_hidden_state` is mean-pooled over dimension 1 and the
    pooled features are passed to `classifier`.

    Args:
        dino: backbone called as `dino(**inputs)`; its output must expose
            `last_hidden_state`.
        processor: image processor called as `processor(x, return_tensors="pt")`;
            its result must support `.to(device)` and `**` unpacking.
        classifier: head applied to the pooled features (e.g. `DiseaseClassifier`).
    """

    def __init__(self, dino, processor, classifier):
        super().__init__()
        self.dino = dino
        self.classifier = classifier
        self.processor = processor

    def forward(self, x):
        """Classify raw images `x`; returns the classifier's output (logits)."""
        inputs = self.processor(x, return_tensors="pt").to(device)
        outputs = self.dino(**inputs)
        # NOTE(review): assumes last_hidden_state is (batch, tokens, hidden); confirm the
        # embeddings the head was trained on were pooled the same way.
        pooled = torch.mean(outputs.last_hidden_state, dim=1)
        values = self.classifier(pooled)
        # Bug fix: this previously returned the undefined name `value` (NameError).
        return values

In [12]:
# Fresh classification head, placed on the same device as the dataset tensors so that
# switching `device` to a GPU does not cause a device mismatch during training.
classifier = DiseaseClassifier().to(device)

In [13]:
# Define optimizer and loss function
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# trainDataloader = DataLoader(trainingData, batch_size=batch_size, shuffle=True)
# testDataloader = DataLoader(testData, batch_size=batch_size, shuffle=True)



epochs = 1000


In [14]:
# Mini-batch size: one sample per optimisation step.
batch_size: int = 1

In [15]:
def train_loop(dataloader, model, loss_fn, optimizer, window_size=100):
    """Run one training epoch and track a rolling mean of the per-batch loss.

    Each batch is moved to the global `device`, passed through `model`, scored with
    `loss_fn`, back-propagated, and followed by one `optimizer` step. A progress line
    is printed on every 100th batch.

    Args:
        dataloader: iterable of `(X, y)` batches exposing `.dataset`.
        model: module to train (switched to train mode).
        loss_fn: callable `(pred, y) -> scalar loss tensor`.
        optimizer: optimizer over `model`'s parameters.
        window_size: how many of the most recent batch losses are averaged.

    Returns:
        list[float]: the rolling-average loss recorded after every batch.
    """
    n_samples = len(dataloader.dataset)
    model.train()

    recent = []   # sliding window holding at most `window_size` batch losses
    history = []  # rolling average after each batch

    for step, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)

        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_loss = loss.item()
        recent.append(batch_loss)
        if len(recent) > window_size:
            del recent[0]  # drop the oldest loss so the window keeps its size
        avg = sum(recent) / len(recent)
        history.append(avg)

        if step % 100 == 0:
            seen = step * len(inputs)
            print(f"loss: {batch_loss:>7f}, rolling avg loss: {avg:>7f}  [{seen:>5d}/{n_samples:>5d}]")

    return history

In [16]:
# Make sure the DINOv2 backbone sits on the configured device (Module.to returns the module).
model = model.to(device=device)

In [17]:
# Training batches, reshuffled every epoch; the batch size comes from the `batch_size`
# config value instead of a duplicated literal.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [18]:
# Evaluation batches: metrics do not depend on order, so skip shuffling and keep
# iteration deterministic (reproducible per-sample inspection).
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [19]:
# Train for a fixed number of epochs, keeping each epoch's rolling-loss curve
# (train_loop's return value) so it can be inspected or plotted afterwards.
# NOTE(review): a separate `epochs` value (1000) is defined earlier but not used here.
NUM_EPOCHS = 100
loss_curves = []
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    loss_curves.append(train_loop(train_dataloader, classifier, criterion, optimizer))

  if self.orig_query.grad is None:


loss: 2.191085, rolling avg loss: 2.191085  [    0/89996]
loss: 1.908982, rolling avg loss: 1.885128  [  100/89996]
loss: 0.146316, rolling avg loss: 1.131842  [  200/89996]
loss: 1.177288, rolling avg loss: 0.676339  [  300/89996]
loss: 1.851651, rolling avg loss: 0.597197  [  400/89996]
loss: 0.480742, rolling avg loss: 0.647838  [  500/89996]
loss: 0.003085, rolling avg loss: 0.580665  [  600/89996]
loss: 0.104599, rolling avg loss: 0.479732  [  700/89996]
loss: 0.048798, rolling avg loss: 0.441239  [  800/89996]
loss: 0.267393, rolling avg loss: 0.431533  [  900/89996]
loss: 0.255576, rolling avg loss: 0.297778  [ 1000/89996]
loss: 0.001311, rolling avg loss: 0.504680  [ 1100/89996]
loss: 0.080491, rolling avg loss: 0.458503  [ 1200/89996]
loss: 0.412693, rolling avg loss: 0.358777  [ 1300/89996]
loss: 1.175157, rolling avg loss: 0.509278  [ 1400/89996]
loss: 0.014059, rolling avg loss: 0.300433  [ 1500/89996]
loss: 0.168135, rolling avg loss: 0.348651  [ 1600/89996]
loss: 0.182561


KeyboardInterrupt



In [20]:
def test_loop(dataloader, model, loss_fn, num_classes=9):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    
    # Initialize confusion matrix
    confusion_matrix = torch.zeros(num_classes, num_classes)
    
    # Set the model to evaluation mode
    model.eval()
    
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            
            # Get predicted class
            predicted = pred.argmax(1)
            
            # Update confusion matrix
            for t, p in zip(y.view(-1), predicted.view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
    
    # Calculate average loss
    test_loss /= num_batches
    
    # Calculate metrics
    metrics = {}
    
    # Overall accuracy
    accuracy = confusion_matrix.diag().sum() / confusion_matrix.sum()
    
    # Per-class metrics
    for class_idx in range(num_classes):
        # True positives: diagonal elements
        tp = confusion_matrix[class_idx, class_idx]
        
        # False positives: sum of column minus true positive
        fp = confusion_matrix[:, class_idx].sum() - tp
        
        # False negatives: sum of row minus true positive
        fn = confusion_matrix[class_idx, :].sum() - tp
        
        # True negatives: all minus tp, fp, fn
        tn = confusion_matrix.sum() - tp - fp - fn
        
        # Calculate metrics
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        
        metrics[f"class_{class_idx}"] = {
            "precision": precision.item(),
            "recall": recall.item(),
            "f1": f1.item(),
            "true_positives": tp.item(),
            "false_positives": fp.item(),
            "true_negatives": tn.item(),
            "false_negatives": fn.item()
        }
    
    # Print summary
    print(f"Test Results: \n Accuracy: {(100*accuracy):>0.1f}%, Avg loss: {test_loss:>8f}")
    print("\nConfusion Matrix:")
    print(confusion_matrix)
    print("\nPer-class metrics:")
    for class_idx, class_metrics in metrics.items():
        print(f"{class_idx}:")
        print(f"  Precision: {class_metrics['precision']:.4f}")
        print(f"  Recall: {class_metrics['recall']:.4f}")
        print(f"  F1-Score: {class_metrics['f1']:.4f}")
        print(f"  TP: {class_metrics['true_positives']}, FP: {class_metrics['false_positives']}")
        print(f"  TN: {class_metrics['true_negatives']}, FN: {class_metrics['false_negatives']}")
        print("")
    
    return {
        "loss": test_loss,
        "accuracy": accuracy.item(),
        "confusion_matrix": confusion_matrix,
        "class_metrics": metrics
    }

In [21]:
# Evaluate the trained head on the held-out split; the returned dict is the cell output.
test_loop(dataloader=test_dataloader, model=classifier, loss_fn=criterion)

Test Results: 
 Accuracy: 91.3%, Avg loss: 0.314782

Confusion Matrix:
tensor([[1.2960e+03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.8000e+01,
         0.0000e+00, 0.0000e+00, 4.0000e+00],
        [6.0000e+00, 8.4000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 3.2000e+02, 3.0000e+00, 0.0000e+00, 1.4000e+01,
         0.0000e+00, 0.0000e+00, 2.0000e+00],
        [0.0000e+00, 0.0000e+00, 1.0000e+00, 6.3100e+02, 0.0000e+00, 1.0000e+00,
         0.0000e+00, 0.0000e+00, 1.0000e+00],
        [4.0000e+00, 0.0000e+00, 6.0000e+00, 0.0000e+00, 9.6900e+02, 2.3000e+01,
         1.6000e+01, 1.2000e+01, 5.0000e+00],
        [0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, 0.0000e+00, 5.5800e+02,
         0.0000e+00, 3.2000e+01, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0000e+00, 1.0000e+00, 3.0000e+00,
         6.8900e+02, 0.0000e+00, 4.2000e+01],
        [0.0000e+00, 0.0000e+00, 5.0000e

{'loss': 0.3147820494856667,
 'accuracy': 0.912674069404602,
 'confusion_matrix': tensor([[1.2960e+03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.8000e+01,
          0.0000e+00, 0.0000e+00, 4.0000e+00],
         [6.0000e+00, 8.4000e+02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 3.2000e+02, 3.0000e+00, 0.0000e+00, 1.4000e+01,
          0.0000e+00, 0.0000e+00, 2.0000e+00],
         [0.0000e+00, 0.0000e+00, 1.0000e+00, 6.3100e+02, 0.0000e+00, 1.0000e+00,
          0.0000e+00, 0.0000e+00, 1.0000e+00],
         [4.0000e+00, 0.0000e+00, 6.0000e+00, 0.0000e+00, 9.6900e+02, 2.3000e+01,
          1.6000e+01, 1.2000e+01, 5.0000e+00],
         [0.0000e+00, 0.0000e+00, 2.0000e+00, 0.0000e+00, 0.0000e+00, 5.5800e+02,
          0.0000e+00, 3.2000e+01, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 6.0000e+00, 1.0000e+00, 3.0000e+00,
          6.8900e+02, 0.0000e+00, 4.2000e+01],
         [0.0000

In [None]:
def plot_confusion_matrix(confusion_matrix, class_names=None, normalize=False, title="Confusion Matrix", cmap=plt.cm.Blues, ax=None):
    """
    Plot a confusion matrix with clear labels, per-cell annotations and a colorbar.

    Args:
        confusion_matrix: square matrix (torch tensor or array-like); rows are true
            classes, columns predicted classes (e.g. the one returned by test_loop).
        class_names: list of class names (optional, defaults to indices).
        normalize: if True, divide each row by its total so cells show the fraction
            of each true class; rows with no samples are shown as 0.
        title: plot title (default: "Confusion Matrix").
        cmap: colormap (default: plt.cm.Blues).
        ax: matplotlib Axes to draw into. If None (default), a new 10x8 figure is
            created, matching the previous behaviour.

    Returns:
        The `matplotlib.pyplot` module (kept for backward compatibility).
    """
    # Accept torch tensors (possibly on GPU) as well as anything np.array handles.
    if hasattr(confusion_matrix, 'cpu'):
        cm = confusion_matrix.cpu().numpy()
    else:
        cm = np.array(confusion_matrix)

    if class_names is None:
        class_names = [str(i) for i in range(cm.shape[0])]

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True).astype(float)
        # Avoid 0/0 (NaN plus a RuntimeWarning) for classes with no samples.
        cm = np.divide(cm.astype(float), row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums > 0)
        fmt = '.2f'
    elif np.all(np.mod(cm, 1) == 0):
        # test_loop returns a float tensor of whole-number counts, and format code 'd'
        # raises ValueError for floats, so cast exact counts to int before annotating.
        cm = cm.astype(np.int64)
        fmt = 'd'
    else:
        fmt = '.2f'

    if ax is None:
        fig, ax = plt.subplots(figsize=(10, 8))
    else:
        fig = ax.figure

    image = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.set_title(title, fontsize=15)
    fig.colorbar(image, ax=ax)

    tick_marks = np.arange(len(class_names))
    ax.set_xticks(tick_marks)
    ax.set_xticklabels(class_names, rotation=45, ha='right', fontsize=10)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(class_names, fontsize=10)

    # Dark cells get white text, light cells black text.
    thresh = cm.max() / 2. if cm.size else 0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black",
                    fontsize=9)

    ax.set_ylabel('True label', fontsize=12)
    ax.set_xlabel('Predicted label', fontsize=12)
    # Lay out after the axis labels exist so they are not clipped.
    fig.tight_layout()

    return plt


def test_with_visualization(dataloader, model, loss_fn, num_classes=9, class_names=None):
    """Run test_loop, then plot raw and row-normalised confusion matrices side by side.

    Args:
        dataloader, model, loss_fn, num_classes: forwarded to test_loop.
        class_names: tick labels; defaults to the names in the global `label_names`.

    Returns:
        The results dict returned by test_loop.
    """
    results = test_loop(dataloader, model, loss_fn, num_classes)
    conf_matrix = results['confusion_matrix']

    if class_names is None:
        # Accept label_names either as a mapping (use its values in key order) or as a list.
        if isinstance(label_names, dict):
            class_names = list(label_names.values())
        else:
            class_names = list(label_names)

    # One figure with two panels; plot_confusion_matrix draws into the given Axes
    # instead of opening its own figure, so both panels end up here.
    fig, (ax_counts, ax_norm) = plt.subplots(1, 2, figsize=(20, 8))
    plot_confusion_matrix(
        conf_matrix,
        class_names=class_names,
        title="Confusion Matrix (Counts)",
        ax=ax_counts,
    )
    plot_confusion_matrix(
        conf_matrix,
        class_names=class_names,
        normalize=True,
        title="Confusion Matrix (Normalized)",
        ax=ax_norm,
    )

    fig.tight_layout()
    plt.show()

    return results

In [None]:
# Sanity check: logits for test sample 2 (unsqueeze adds the batch dimension).
with torch.no_grad():
    sample_logits = classifier(test_dataset[2][0].unsqueeze(0))
sample_logits

In [None]:
# True label of test sample 2, for comparison with the predicted logits.
test_dataset[2][1]