In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import pandas as pd
import numpy as np
from pathlib import Path
import tifffile
import cv2
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler

  warn(


In [2]:


class DiatomDataset(Dataset):
    """Micrograph dataset yielding ``(image, standardized count)`` pairs for count regression.

    Each CSV row names a micrograph in its ``micrograph ID`` column (loaded from
    ``<img_dir>/<id>.tif``) together with a count in ``target_column``. Counts are
    standardized with a ``StandardScaler`` fitted on all rows; ``self.scaler`` maps
    model outputs back to raw counts via ``inverse_transform``.
    """

    def __init__(self, csv_path, img_dir, transform=None, image_size=512,
                 target_column='Cocconeis'):
        """
        Args:
            csv_path: Path to the CSV file with annotations.
            img_dir: Directory containing the ``.tif`` micrographs.
            transform: Optional callable applied to the preprocessed
                (image_size, image_size, 3) float32 image in [0, 1].
            image_size: Side length in pixels that images are resized to (default 512).
            target_column: CSV column holding the count to regress (default 'Cocconeis').
        """
        self.data = pd.read_csv(csv_path)
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.image_size = image_size

        # NOTE(review): the scaler is fitted on every row, including rows that later land
        # in a validation fold, so validation statistics leak into the target scaling.
        self.counts = self.data[target_column].values
        self.scaler = StandardScaler()
        self.normalized_counts = self.scaler.fit_transform(self.counts.reshape(-1, 1))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, count)`` where ``count`` is a float32 tensor of shape (1,)."""
        img_name = self.data.iloc[idx]['micrograph ID']
        img_path = self.img_dir / f"{img_name}.tif"
        image = self.preprocess_image(tifffile.imread(str(img_path)))

        if self.transform:
            image = self.transform(image)

        # Row of the (N, 1) standardized-count matrix -> shape (1,).
        count = self.normalized_counts[idx]
        return image, torch.tensor(count, dtype=torch.float32)

    def preprocess_image(self, image):
        """Turn a raw micrograph into a contrast-enhanced (S, S, 3) float32 array in [0, 1].

        Steps: scale to uint8, expand/trim to three channels, resize to ``image_size``
        with area interpolation, then apply CLAHE to each channel independently.

        The output is float32 in [0, 1] because ``transforms.ToTensor`` does not rescale
        float arrays and keeps their dtype. A float64 array would become a double tensor
        that does not match the model's float32 weights. ImageNet ``Normalize`` also
        expects [0, 1] inputs.
        """
        rgb = self._to_three_channels(self._to_uint8(image))
        size = self.image_size
        resized = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_AREA)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        # CLAHE works on single-channel images; contiguous copies keep OpenCV happy.
        enhanced = np.stack(
            [clahe.apply(np.ascontiguousarray(resized[:, :, c])) for c in range(3)],
            axis=-1,
        )
        return enhanced.astype(np.float32) / 255.0

    @staticmethod
    def _to_uint8(image):
        """Return ``image`` as uint8, rescaling (not wrapping) wider dtypes into 0-255."""
        if image.dtype == np.uint8:
            return image
        if np.issubdtype(image.dtype, np.integer):
            # e.g. 16-bit micrographs: map the dtype's full range onto 0-255.
            scaled = image.astype(np.float32) * (255.0 / np.iinfo(image.dtype).max)
        else:
            # Float/bool image of unknown range: stretch min..max onto 0..255.
            lo, hi = float(image.min()), float(image.max())
            if hi > lo:
                scaled = (image.astype(np.float32) - lo) * (255.0 / (hi - lo))
            else:
                scaled = np.zeros(image.shape, dtype=np.float32)
        return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)

    @staticmethod
    def _to_three_channels(image):
        """Return a contiguous (H, W, 3) array from an (H, W), (H, W, 1), (H, W, 3) or (H, W, 4) image.

        Raises:
            ValueError: for any other shape (e.g. channel-first stacks).
        """
        if image.ndim == 2:
            return np.stack([image] * 3, axis=-1)
        if image.ndim == 3:
            channels = image.shape[-1]
            if channels == 1:
                return np.repeat(image, 3, axis=-1)
            if channels in (3, 4):
                # tifffile returns samples in their stored (RGB) order, unlike cv2.imread's
                # BGR, so no channel swap is applied; a 4th (alpha) channel is dropped.
                return np.ascontiguousarray(image[..., :3])
        raise ValueError(
            f"Unsupported image shape {image.shape}; expected (H, W) or (H, W, C) with C in 1, 3, 4"
        )

class DiatomCountPredictor(nn.Module):
    """EfficientNet-B0 regressor that predicts a single value per image.

    A 1x1-conv + sigmoid spatial gate re-weights the backbone's feature map before
    global average pooling, and an MLP head maps the pooled features to one output.
    """

    def __init__(self, pretrained=True):
        """
        Args:
            pretrained: If True (default), initialise the backbone with torchvision's
                ImageNet-1k weights; pass False for a randomly initialised backbone.
        """
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag; IMAGENET1K_V1 is the
        # checkpoint that `pretrained=True` used to load for EfficientNet-B0.
        weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1 if pretrained else None
        self.backbone = models.efficientnet_b0(weights=weights)

        # Width of the backbone's final feature map, read from the stock head's Linear layer.
        num_ftrs = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(num_ftrs, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

        # Spatial attention: one gate value in (0, 1) per feature-map location. Sized from
        # num_ftrs instead of a hard-coded 1280 so it always matches the backbone.
        self.attention = nn.Sequential(
            nn.Conv2d(num_ftrs, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Predict one value per image.

        Args:
            x: Image batch, presumably (B, 3, H, W) and ImageNet-normalized — TODO confirm.

        Returns:
            Tensor of shape (B, 1).
        """
        features = self.backbone.features(x)

        # Single-channel gate broadcast across all feature channels.
        features = features * self.attention(features)

        pooled = torch.flatten(self.backbone.avgpool(features), 1)
        return self.backbone.classifier(pooled)

def train_model(model, train_loader, val_loader, num_epochs=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)
    
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        predictions = []
        actual = []
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                predictions.extend(outputs.cpu().numpy())
                actual.extend(labels.cpu().numpy())
        
        # Calculate percentage error
        pred_array = np.array(predictions)
        actual_array = np.array(actual)
        percentage_error = np.mean(np.abs((pred_array - actual_array) / actual_array)) * 100
        
        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Training Loss: {train_loss/len(train_loader):.4f}')
        print(f'Validation Loss: {val_loss/len(val_loader):.4f}')
        print(f'Mean Percentage Error: {percentage_error:.2f}%')
        
        # Learning rate scheduling
        scheduler.step(val_loss)
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')

# Channel statistics of the ImageNet data the EfficientNet backbone was pretrained on.
# Normalize expects inputs already scaled to [0, 1].
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time augmentation. ToTensor keeps the dtype of float numpy arrays (float64
# becomes a double tensor), so cast to float32 to match the model weights. The cast is
# a no-op for uint8 inputs, which ToTensor already turns into float32.
# NOTE: ColorJitter clamps float tensors to [0, 1], so inputs should be in that range.
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.RandomRotation(30),
    transforms.RandomAffine(degrees=0, scale=(0.8, 1.2)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic evaluation pipeline: same tensor conversion and normalization, no augmentation.
val_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.ConvertImageDtype(torch.float32),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [4]:
# Data locations and cross-validation settings.
# NOTE(review): absolute cluster path — consider making it configurable per machine.
ANNOTATIONS_CSV = 'Kraken_2023_measurements.csv'
IMAGE_DIR = Path('/projects/genomic-ml/da2343/diatom/train')
N_SPLITS = 5
SEED = 42

# TODO(review): this single dataset applies train_transforms (random augmentation) to every
# sample. If validation folds are taken as subsets of it, they are augmented too. A second
# DiatomDataset built with val_transforms for the validation indices would avoid that.
dataset = DiatomDataset(
    csv_path=ANNOTATIONS_CSV,
    img_dir=IMAGE_DIR,
    transform=train_transforms
)

# Create k-fold splits; a fixed random_state makes the fold assignment reproducible
# across kernel restarts (shuffle=True without it reshuffles on every run).
kfold = KFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)