In [9]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [10]:
PROJECT_PATH = '/content/drive/MyDrive/classification_project'

In [11]:

import os
print("Files:", os.listdir(f"{PROJECT_PATH}/1-100GEOJSON"))
import pandas as pd
import geopandas as gpd
import json


all_files = os.listdir(f"{PROJECT_PATH}/1-100GEOJSON")
classification_stats = pd.DataFrame()

# Load one file to see structure
test_file = [f for f in all_files if f.endswith('.geojson')][0]
gdf = gpd.read_file(f"{PROJECT_PATH}/1-100GEOJSON/{test_file}")
print("\nSample file structure:")
print(gdf.columns)
print("\nClassification types:")
print(gdf['classification'].head())

Files: ['TCGA-AC-A3BB-01Z-00-DX1.CE889249-2A5E-44DA-B04E-746BE82CD805.geojson', 'TCGA-A7-A0CG-01Z-00-DX1.D77019C2-96B1-4EF5-A61E-5F2D5B8D9852.geojson', 'TCGA-AO-A0J5-01Z-00-DX1.20C14D0C-1A74-4FE9-A5E6-BDDCB8DE7714.geojson', 'TCGA-E9-A1N4-01Z-00-DX1.71c8d4a5-ec99-4012-9fe2-ddb3349ad5bc.geojson', 'TCGA-E9-A1NC-01Z-00-DX1.20edf036-8ba6-4187-a74c-124fc39f5aa1.geojson', 'TCGA-JL-A3YW-01Z-00-DX1.827C5C53-9C30-4307-802A-5A7896828A7F.geojson', 'TCGA-C8-A137-01Z-00-DX1.87F3775D-A401-4D5E-843F-8FB1D4BE97F8.geojson', 'TCGA-PL-A8LY-01A-02-DX2.6F9520F1-3210-4A96-81C2-A14424F650D1.geojson', 'TCGA-AO-A124-01Z-00-DX1.E3C7B017-6154-4630-9BDE-0CAC946D0209.geojson', 'TCGA-AO-A03U-01Z-00-DX1.AE2B55F3-8BA1-4546-82B7-4D2292BE1C78.geojson', 'TCGA-BH-A0H9-01Z-00-DX1.8AE869C6-5C78-4D52-AC8B-5B6FD5FD91AA.geojson', 'TCGA-A7-A4SF-01Z-00-DX1.CDCFD4BC-4363-4CF2-95F5-4922E04C3B9D.geojson', 'TCGA-AR-A24H-01Z-00-DX1.5CFC7E16-3F38-4531-968C-A4E4C9D00659.geojson', 'TCGA-C8-A1HI-01Z-00-DX1.C6D0F8B8-55ED-477F-BAF7-AA05D04

In [12]:
import pandas as pd
import geopandas as gpd
import json

geojson_files = [f for f in os.listdir(f"{PROJECT_PATH}/1-100GEOJSON") if f.endswith('.geojson')]

def process_geojson(file_path):
    gdf = gpd.read_file(file_path)

    gdf['classification'] = gdf['classification'].apply(
        lambda x: json.loads(x) if isinstance(x, str) else x
    )
    gdf['class_name'] = gdf['classification'].apply(
        lambda x: x['name'] if isinstance(x, dict) and 'name' in x else None
    )
    return gdf

sample_file = geojson_files[0]
sample_gdf = process_geojson(f"{PROJECT_PATH}/1-100GEOJSON/{sample_file}")
print("\nSample data structure:")
print(sample_gdf[['id', 'objectType', 'class_name']].head())


all_classes = []
for file in geojson_files:
    gdf = process_geojson(f"{PROJECT_PATH}/1-100GEOJSON/{file}")
    all_classes.extend(gdf['class_name'].tolist())

class_distribution = pd.Series(all_classes).value_counts()
print(class_distribution)


Sample data structure:
                                     id  objectType   class_name
0  66dec096-daf4-4d19-bac2-6160e7b99d32  annotation  vasculature
adipose tissue                 20206
vasculature                     1782
nomal breast gland              1048
invasive cancer                  447
Immune cells                     248
cancer in situ                   229
atypical ductal hyperplasia      229
Necrosis                          41
Name: count, dtype: int64


In [13]:
# Dictionary to store image-annotation pairs
data_pairs = {}
for geojson_file in geojson_files:
    base_name = geojson_file.replace('.geojson', '')
    data_pairs[base_name] = {
        'geojson': geojson_file,
        'csv': base_name + '.csv'
    }

print(f"\nNumber of image-annotation pairs: {len(data_pairs)}")


Number of image-annotation pairs: 100


In [14]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.cuda as cuda
import gc
from tqdm import tqdm
import logging
import sys
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.utils.class_weight import compute_class_weight
from shapely.geometry import Point
import cv2
from scipy import stats
from skimage import filters
from sklearn.metrics import confusion_matrix, classification_report, cohen_kappa_score
from torch.optim.lr_scheduler import OneCycleLR


In [15]:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training.log'),
        logging.StreamHandler(sys.stdout)
    ]
)

In [16]:

def setup_training_config():
    config = {
        # Model parameters
        'hidden_dim': 768,
        'num_heads': 12,
        'num_layers': 12,
        'patch_size': 16,

        # Training parameters
        'batch_size': 32,
        'num_epochs': 100,
        'learning_rate': 2e-4,
        'weight_decay': 0.05,
        'warmup_epochs': 10,

        # Regularization
        'dropout': 0.1,
        'label_smoothing': 0.1,

        # Optimizer parameters
        'optimizer_params': {
            'betas': (0.9, 0.999),
            'eps': 1e-8
        },

        # Learning rate scheduler
        'scheduler_params': {
            'num_warmup_steps': 0,  # Will be calculated based on warmup_epochs
            'num_training_steps': 0  # Will be calculated based on dataset size
        }
    }

    return config


config = setup_training_config()

PATCH_SIZE = config['patch_size']
BATCH_SIZE = config['batch_size']
num_epochs = config['num_epochs']
CHECKPOINT_FREQ = 5
LEARNING_RATE = config['learning_rate']

def setup_training():
    if torch.cuda.is_available():
        logging.info(f"GPU available: {torch.cuda.get_device_name(0)}")
        logging.info(f"Initial GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")

        torch.cuda.empty_cache()
        gc.collect()
    else:
        logging.warning("No GPU available, using CPU")

    checkpoint_dir = os.path.join(PROJECT_PATH, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_dir

def save_checkpoint(epoch, model, optimizer, loss, acc, checkpoint_dir):
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'accuracy': acc
    }, checkpoint_path)
    logging.info(f"Saved checkpoint for epoch {epoch}")

class CancerDataset(Dataset):
    def __init__(self, geojson_dir, transform=None, phase='train'):
        self.geojson_dir = geojson_dir
        self.transform = transform
        self.phase = phase


        self.class_mapping = {
            'adipose tissue': 0,
            'vasculature': 1,
            'nomal breast gland': 2,
            'invasive cancer': 3,
            'Immune cells': 4,
            'atypical ductal hyperplasia': 5,
            'cancer in situ': 6,
            'Necrosis': 7
        }

        self.annotations = self._load_annotations()

        # Calculate class weights for balanced sampling
        self.class_weights = self._calculate_class_weights()

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
            ])

    def _load_annotations(self):
        annotations = []
        geojson_files = [f for f in os.listdir(self.geojson_dir) if f.endswith('.geojson')]

        # Add progress bar
        for i, file in enumerate(tqdm(geojson_files, desc='Loading annotations')):
            try:
                gdf = gpd.read_file(os.path.join(self.geojson_dir, file))
                gdf['classification'] = gdf['classification'].apply(
                    lambda x: json.loads(x) if isinstance(x, str) else x
                )
                gdf['class_name'] = gdf['classification'].apply(
                    lambda x: x['name'] if isinstance(x, dict) and 'name' in x else None
                )

                for _, row in gdf.iterrows():
                    if row['class_name'] in self.class_mapping:
                        annotations.append({
                            'file': file,
                            'geometry': row['geometry'],
                            'class_name': row['class_name'],
                            'class': self.class_mapping[row['class_name']]
                        })
            except Exception as e:
                print(f"Error processing file {file}: {str(e)}")

        return annotations

    def calculate_class_weights(dataset):
      # Count samples per class
      class_counts = {}
      for annotation in dataset.annotations:
          class_counts[annotation['class']] = class_counts.get(annotation['class'], 0) + 1

      # Calculate balanced weights
      total_samples = len(dataset.annotations)
      num_classes = len(dataset.class_mapping)
      weights = {}

      for cls, count in class_counts.items():
          # Modified weight calculation using effective number of samples
          beta = 0.9999
          effective_num = 1.0 - np.power(beta, count)
          weights[cls] = (1.0 - beta) / effective_num

      # Normalize weights
      weight_sum = sum(weights.values())
      weights = {cls: weight/weight_sum * num_classes for cls, weight in weights.items()}

      return weights

    def __getitem__(self, idx):
        annotation = self.annotations[idx]

        features = self._extract_patch_features(annotation['geometry'])
        features = torch.FloatTensor(features).permute(2, 0, 1)

        # Apply transformations
        if self.transform:
            features = self.transform(features)

        return features, annotation['class']

    def _extract_patch_features(self, geometry):
        minx, miny, maxx, maxy = geometry.bounds
        area = geometry.area
        perimeter = geometry.length


        compactness = 4 * np.pi * area / (perimeter * perimeter) if perimeter > 0 else 0
        elongation = (maxx - minx) / (maxy - miny) if (maxy - miny) > 0 else 1

        features = np.zeros((PATCH_SIZE, PATCH_SIZE, 5))



        features[:,:,0] = area / (PATCH_SIZE * PATCH_SIZE)
        features[:,:,1] = perimeter / (4 * PATCH_SIZE)
        features[:,:,2] = compactness
        features[:,:,3] = elongation
        features[:,:,4] = geometry.area / geometry.convex_hull.area if not geometry.is_empty else 1

        return features

    def __len__(self):
        return len(self.annotations)

    def get_sample_weights(self):
        weights = [self.class_weights[annotation['class']] for annotation in self.annotations]
        return torch.DoubleTensor(weights)


In [17]:
def process_whole_slide(wsi_path):
    """
    """

    wsi = openslide.OpenSlide(wsi_path)
    width = wsi.dimensions[0]
    height = wsi.dimensions[1]

    tile_size = 512

    tiles = []
    coords = []

    thumbnail = wsi.get_thumbnail((width//32, height//32))
    thumbnail_gray = cv2.cvtColor(np.array(thumbnail), cv2.COLOR_RGB2GRAY)
    tissue_mask = thumbnail_gray < filters.threshold_otsu(thumbnail_gray)

    # Extract tiles
    for y in range(0, height, tile_size):
        for x in range(0, width, tile_size):
            mask_x = x // 32
            mask_y = y // 32
            if mask_x < tissue_mask.shape[1] and mask_y < tissue_mask.shape[0]:
                if tissue_mask[mask_y, mask_x]:
                    tile = wsi.read_region((x, y), 0, (tile_size, tile_size))
                    tile = tile.convert('RGB')

                    # Filter out background tiles
                    tile_array = np.array(tile)
                    if np.mean(tile_array) < 240:
                        tiles.append(tile_array)
                        coords.append((x, y))

    return np.array(tiles), coords


In [18]:
def pixel_wise_labeling(model, wsi):
    """
    """
    height, width = wsi.shape[:2]

    prediction_mask = np.zeros((height, width), dtype=np.uint8)
    window_size = 512
    stride = 256

    for y in range(0, height - window_size + 1, stride):
        for x in range(0, width - window_size + 1, stride):
            patch = wsi[y:y+window_size, x:x+window_size]
            patch = cv2.resize(patch, (224, 224))
            patch = patch / 255.0
            patch = np.expand_dims(patch, axis=0)


            pred = model.predict(patch)
            label = np.argmax(pred[0])

            prediction_mask[y:y+window_size, x:x+window_size] = label

    return prediction_mask

In [19]:
def compare_baseline(our_results, baseline_results):
    """

    """
    metrics = {
        'confusion_matrix': confusion_matrix(baseline_results, our_results),
        'classification_report': classification_report(baseline_results, our_results, output_dict=True),
        'cohen_kappa': cohen_kappa_score(baseline_results, our_results)
    }

    class_accuracies = {}
    for class_name in set(baseline_results):
        class_mask = (baseline_results == class_name)
        class_acc = np.mean(our_results[class_mask] == baseline_results[class_mask])
        class_accuracies[f'class_{class_name}_accuracy'] = class_acc

    metrics['class_accuracies'] = class_accuracies

    # Statistical significance
    contingency_table = pd.crosstab(pd.Series(baseline_results), pd.Series(our_results))
    chi2, p_value = stats.chi2_contingency(contingency_table)[:2]
    metrics['statistical_tests'] = {
        'chi2': chi2,
        'p_value': p_value
    }

    return metrics

In [20]:
class VisionTransformer(nn.Module):
    def __init__(self, num_classes=8, patch_size=16, hidden_dim=768, num_heads=12, num_layers=12):
        super().__init__()

        # Input normalization
        self.input_norm = nn.LayerNorm(5)  # 5 channels from geometric features

        # Improved patch embedding
        self.patch_embed = nn.Sequential(
            nn.Conv2d(5, hidden_dim, kernel_size=patch_size, stride=patch_size),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # Position embedding with learned parameters
        self.pos_embed = nn.Parameter(torch.zeros(1, (224 // patch_size) ** 2 + 1, hidden_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))

        # Transformer encoder with layer normalization before attention and MLP
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation='gelu',
            norm_first=True  # Pre-norm architecture
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Classification head with dropout
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        # Input shape: [B, C, H, W]
        B = x.shape[0]

        # Normalize input
        x = x.permute(0, 2, 3, 1)  # [B, H, W, C]
        x = self.input_norm(x)
        x = x.permute(0, 3, 1, 2)  # [B, C, H, W]

        # Patch embedding
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]

        # Add classification token and position embeddings
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed

        # Transform
        x = self.transformer(x)

        # Classification from [CLS] token
        x = x[:, 0]
        x = self.head(x)

        return x

In [21]:
def sample_weights(dataset):

    labels = []
    subset_indices = dataset.indices if hasattr(dataset, 'indices') else range(len(dataset.dataset))

    for idx in subset_indices:
        if hasattr(dataset, 'dataset'):
            _, label = dataset.dataset[idx]
        else:
            _, label = dataset[idx]
        labels.append(label)

    labels = np.array(labels)

    # Calculate distribution
    unique_labels, counts = np.unique(labels, return_counts=True)
    print("Class distribution:")
    for label, count in zip(unique_labels, counts):
        print(f"Class {label}: {count} samples")

    # Calculate weights
    total_samples = len(labels)
    class_weights = total_samples / (len(unique_labels) * counts)

    # Create sample weights
    sample_weights = [class_weights[label] for label in labels]
    print("Sample weights calculated successfully")
    return torch.FloatTensor(sample_weights)

In [22]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, checkpoint_dir, gradient_clip_val=1.0, grad_accum_steps=2):
    best_val_acc = 0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': [],
               'train_f1': [], 'val_f1': []}

    patience = 10
    early_stop_counter = 0
    best_val_f1 = 0

    for epoch in range(num_epochs):
        logging.info(f'\nEpoch {epoch+1}/{num_epochs}')
        logging.info('-' * 30)

        # Training phase
        model.train()
        train_loss = 0
        train_correct = 0
        train_total = 0
        train_preds = []
        train_targets = []
        optimizer.zero_grad()  # Zero gradients at the start of epoch

        train_pbar = tqdm(train_loader, desc=f'Training Epoch {epoch+1}')

        for batch_idx, (inputs, targets) in enumerate(train_pbar):
            try:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # Normalize loss for gradient accumulation
                loss = loss / grad_accum_steps
                loss.backward()

                # Gradient accumulation
                if (batch_idx + 1) % grad_accum_steps == 0 or batch_idx == len(train_loader) - 1:
                    # Clip gradients
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip_val)
                    optimizer.step()
                    optimizer.zero_grad()

                    if scheduler is not None:
                        scheduler.step()

                train_loss += loss.item() * grad_accum_steps
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()
                train_preds.extend(predicted.cpu().numpy())
                train_targets.extend(targets.cpu().numpy())

                current_acc = 100. * train_correct / train_total
                train_pbar.set_postfix({
                    'loss': f'{loss.item() * grad_accum_steps:.4f}',
                    'acc': f'{current_acc:.2f}%'
                })

            except RuntimeError as e:
                if "out of memory" in str(e):
                    logging.error("GPU OOM encountered, clearing cache and skipping batch")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

        train_loss = train_loss / len(train_loader)
        train_acc = 100. * train_correct / train_total
        train_f1 = f1_score(train_targets, train_preds, average='weighted')

        # Validation phase
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_targets = []

        with torch.no_grad():
            val_pbar = tqdm(val_loader, desc='Validation')
            for inputs, targets in val_pbar:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()
                val_preds.extend(predicted.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

                current_acc = 100. * val_correct / val_total
                val_pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{current_acc:.2f}%'
                })

        val_loss = val_loss / len(val_loader)
        val_acc = 100. * val_correct / val_total
        val_f1 = f1_score(val_targets, val_preds, average='weighted')

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_f1'].append(train_f1)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)

        logging.info(f'\nEpoch Summary:')
        logging.info(f'Training Loss: {train_loss:.4f} | Training Acc: {train_acc:.2f}% | Training F1: {train_f1:.4f}')
        logging.info(f'Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.2f}% | Validation F1: {val_f1:.4f}')

        # Save best model based on F1 score
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_f1': best_val_f1,
                'best_val_acc': val_acc,
            }, os.path.join(checkpoint_dir, 'best_model.pth'))
            early_stop_counter = 0
            logging.info(f'Saved new best model with validation F1: {val_f1:.4f} | Acc: {val_acc:.2f}%')
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            logging.info("Early stopping triggered")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    return model, history


In [23]:
def plot_training_history(history):
    plt.figure(figsize=(12, 4))

    # Loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Training Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.title('Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Accuracy
    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Training Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Model Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(PROJECT_PATH, 'training_history.png'))
    plt.close()


In [24]:
def evaluate_model(model, test_loader, device):
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.numpy())

    return all_preds, all_targets


In [25]:
def plot_confusion_matrix(predictions, targets, class_names):
    cm = confusion_matrix(targets, predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.tight_layout()
    plt.savefig(os.path.join(PROJECT_PATH, 'confusion_matrix.png'))
    plt.close()
    return cm

In [26]:

def visualize_wsi_predictions(predictions, coordinates, save_path):
    """
    """

    colors = {
        0: [255, 0, 0],    # Red for Amplified
        1: [0, 255, 0],    # Green for Normal
        2: [0, 0, 255]     # Blue for Non-Amplified
    }

    height = max(y for _, y in coordinates) + 512
    width = max(x for x, _ in coordinates) + 512
    visualization = np.zeros((height, width, 3), dtype=np.uint8)

    for pred, (x, y) in zip(predictions, coordinates):
        color = colors[pred]
        visualization[y:y+512, x:x+512] = color


    Image.fromarray(visualization).save(save_path)

In [27]:
class GeometricFeatureTransform:
    def __init__(self):
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406, 0.449, 0.449],
            std=[0.229, 0.224, 0.225, 0.226, 0.226]
        )

    def __call__(self, x):
        return self.normalize(x)

In [28]:

def main():
    try:
        if not os.path.exists(GEOJSON_DIR):
            raise ValueError(f"GeoJSON directory not found: {GEOJSON_DIR}")

        torch.manual_seed(42)
        np.random.seed(42)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")


        config = setup_training_config()
        batch_size = config['batch_size']


        train_transform = GeometricFeatureTransform()
        val_transform = GeometricFeatureTransform()


        dataset = CancerDataset(GEOJSON_DIR, transform=train_transform)
        total_size = len(dataset)

        # Stratified split
        labels = [annotation['class'] for annotation in dataset.annotations]
        train_idx, temp_idx = train_test_split(
            range(total_size),
            test_size=0.2,
            stratify=labels,
            random_state=42
        )
        val_idx, test_idx = train_test_split(
            temp_idx,
            test_size=0.5,
            stratify=[labels[i] for i in temp_idx],
            random_state=42
        )

        # Dataset subsets
        train_dataset = torch.utils.data.Subset(dataset, train_idx)
        val_dataset = torch.utils.data.Subset(dataset, val_idx)
        test_dataset = torch.utils.data.Subset(dataset, test_idx)

        print(f"Train size: {len(train_dataset)}")
        print(f"Val size: {len(val_dataset)}")
        print(f"Test size: {len(test_dataset)}")

        # Class weights
        class_counts = np.zeros(len(dataset.class_mapping))
        for idx in train_idx:
            _, label = dataset[idx]
            class_counts[label] += 1

        # Inverse frequency weighting with smoothing
        beta = 0.9999
        effective_num = 1.0 - np.power(beta, class_counts)
        weights_per_class = (1.0 - beta) / np.where(effective_num < 1e-8, 1e-8, effective_num)
        weights_per_class = weights_per_class / np.sum(weights_per_class) * len(weights_per_class)

        # Sampler
        sample_weights = [weights_per_class[labels[idx]] for idx in train_idx]
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(train_idx),
            replacement=True
        )

        # Data loaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            sampler=sampler,
            num_workers=2,  # Reduced from 4 to 2
            pin_memory=True if torch.cuda.is_available() else False
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True if torch.cuda.is_available() else False
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True if torch.cuda.is_available() else False
        )

        # Initialize model
        print("Initializing model...")
        model = ImprovedVisionTransformer(
            num_classes=len(dataset.class_mapping),
            patch_size=config['patch_size'],
            hidden_dim=config['hidden_dim'],
            num_heads=config['num_heads'],
            num_layers=config['num_layers']
        )

        model = model.to(device)

        # Loss & Optimizer
        criterion = create_criterion(weights_per_class, config)
        optimizer = create_optimizer(model, config)







        num_warmup_steps = len(train_loader) * 5
        num_training_steps = len(train_loader) * NUM_EPOCHS

        from transformers import get_cosine_schedule_with_warmup
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        )






        checkpoint_dir = os.path.join(PROJECT_PATH, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Train model
        model, history = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            num_epochs=NUM_EPOCHS,
            device=device,
            checkpoint_dir=checkpoint_dir,
            gradient_clip_val=gradient_clip_val,
            grad_accum_steps=grad_accum_steps
        )


        plot_training_history(history)

        # Load best model
        print("Loading best model...")
        best_model_path = os.path.join(PROJECT_PATH, 'best_model.pth')
        if os.path.exists(best_model_path):
            try:
                state_dict = torch.load(best_model_path)
                if 'model_state_dict' in state_dict:
                    model.load_state_dict(state_dict['model_state_dict'])
                else:
                    model.load_state_dict(state_dict)
                print("Successfully loaded best model")
            except Exception as e:
                print(f"Could not load model: {str(e)}")
                print("Continuing with current model state")

        # Evaluate model
        print("Evaluating model...")
        predictions, targets = evaluate_model(model, test_loader, device)

        # Generate confusion matrix
        print("Generating confusion matrix...")
        class_names = list(dataset.class_mapping.keys())
        cm = plot_confusion_matrix(predictions, targets, class_names)

        # Generate classification report
        print("Classification report...")
        report = classification_report(
            targets,
            predictions,
            target_names=class_names,
            digits=4
        )
        print("\nClassification Report:")
        print(report)






        report_path = os.path.join(PROJECT_PATH, 'classification_report.txt')
        with open(report_path, 'w') as f:
            f.write(report)

        final_results = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history,
            'class_mapping': dataset.class_mapping,
            'final_test_predictions': predictions,
            'final_test_targets': targets,
            'class_weights': weights_per_class.tolist(),
            'confusion_matrix': cm.tolist()
        }

        torch.save(final_results, os.path.join(PROJECT_PATH, 'final_model_results.pth'))

    except Exception as e:
        print(f"\nAn error occurred: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

if __name__ == '__main__':
    main()

Using device: cuda


Loading annotations: 100%|██████████| 100/100 [00:14<00:00,  6.68it/s]


An error occurred: 'CancerDataset' object has no attribute '_calculate_class_weights'



Traceback (most recent call last):
  File "<ipython-input-28-793b9e574260>", line 20, in main
    dataset = CancerDataset(GEOJSON_DIR, transform=train_transform)
  File "<ipython-input-16-050b439081ab>", line 90, in __init__
    self.class_weights = self._calculate_class_weights()
AttributeError: 'CancerDataset' object has no attribute '_calculate_class_weights'


AttributeError: 'CancerDataset' object has no attribute '_calculate_class_weights'