In [1]:
import os
import zipfile
import concurrent.futures
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import numpy as np
from io import BytesIO
from pathlib import Path

class FlowImageDataset(Dataset):
    """Map-style dataset that reads ``.npy`` images directly out of a zip archive.

    Only archive members laid out as ``dataset/<split_folder>/<class>/.../<file>.npy``
    with ``<class>`` in ``self.classes`` are indexed. Each item is returned as
    ``(float32 tensor, class_index)``.

    Args:
        zip_path: Path to the zip archive.
        split_folder: Split directory name under ``dataset/`` (e.g. ``'train'``, ``'val'``).
        transform: Optional callable applied to the loaded tensor.
    """

    def __init__(self, zip_path, split_folder, transform=None):
        self.zip_path = zip_path
        self.split_folder = f"dataset/{split_folder}"  # 'dataset/train' or 'dataset/val'
        self.transform = transform
        self.samples = []
        self.classes = ['no', 'sphere', 'vort']
        # Lazily opened archive handle and the PID that opened it. After a fork,
        # parent and child share the open file description (including its offset),
        # so each process must open its own handle rather than reuse an inherited one.
        self._zip = None
        self._zip_pid = None

        class_to_idx = {name: idx for idx, name in enumerate(self.classes)}

        # Scan the archive once to index every relevant sample.
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            for file_info in zip_ref.infolist():
                path = Path(file_info.filename)
                parts = path.parts
                # Require at least dataset/<split>/<class>/<file>.npy
                if (len(parts) > 3
                        and parts[0] == 'dataset'
                        and parts[1] == split_folder
                        and parts[2] in class_to_idx
                        and path.suffix == '.npy'):
                    self.samples.append((file_info.filename, class_to_idx[parts[2]]))

    def __len__(self):
        return len(self.samples)

    def _archive(self):
        """Return a ZipFile handle owned by the current process, opening it on first use."""
        pid = os.getpid()
        if self._zip is None or self._zip_pid != pid:
            # An inherited handle (from a forked parent) is simply replaced, not closed.
            self._zip = zipfile.ZipFile(self.zip_path, 'r')
            self._zip_pid = pid
        return self._zip

    def __getitem__(self, idx):
        file_path, label = self.samples[idx]

        # Reuse one open archive per process: constructing a ZipFile re-reads the
        # archive's central directory, which is expensive to repeat for every item.
        with self._archive().open(file_path) as file:
            image = np.load(BytesIO(file.read()))
        image = torch.from_numpy(image).float()

        if self.transform:
            image = self.transform(image)

        return image, label

    def __getstate__(self):
        # Never pickle an open file handle (e.g. when sending the dataset to
        # spawn-based DataLoader workers); the receiver reopens it lazily.
        state = self.__dict__.copy()
        state['_zip'] = None
        state['_zip_pid'] = None
        return state

    def close(self):
        """Close the archive handle opened by this process, if any."""
        if self._zip is not None and self._zip_pid == os.getpid():
            self._zip.close()
        self._zip = None
        self._zip_pid = None

def ensure_channel_dim(x):
    """Return ``x`` with a leading singleton channel axis if it is a 2-D image.

    A 2-D ``(H, W)`` tensor becomes ``(1, H, W)``; tensors that already have a
    channel axis (3-D or more) are returned unchanged. Checking ``ndim`` instead
    of ``x.shape[0] == 1`` avoids treating a height-1 2-D image as channel-first
    and avoids adding a spurious 4th axis to a ``(C, H, W)`` tensor with C != 1.
    Assumes the stored arrays are ``(H, W)`` or already ``(C, H, W)`` — TODO confirm.

    A named function (unlike a lambda) can be pickled by reference, which
    multiprocessing DataLoader workers may require.
    """
    return x.unsqueeze(0) if x.ndim == 2 else x


# Define transforms. Images are already min-max normalized according to the
# dataset description, so no extra normalization is applied here.
transform = transforms.Compose([
    transforms.Lambda(ensure_channel_dim),
    # Additional transformations can be added here if needed.
])
# Create datasets directly from the zip file. Archive paths carry a 'dataset/'
# prefix; FlowImageDataset matches it itself, so only the split name is passed.
# (Each construction scans the whole archive, so build each split exactly once.)
train_dataset = FlowImageDataset(zip_path='dataset.zip', split_folder='train', transform=transform)
val_dataset = FlowImageDataset(zip_path='dataset.zip', split_folder='val', transform=transform)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Create data loaders (for efficient batch loading).
BATCH_SIZE = 32
NUM_WORKERS = 4
# Pinned host memory is what allows the training loop's `.to(device, non_blocking=True)`
# copies to run asynchronously; only request it when a CUDA device is available.
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

Training dataset size: 30000
Validation dataset size: 7500


In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models

# Define the model
class FlowResNet(nn.Module):
    """ResNet-18 classifier adapted to single-channel flow images.

    Args:
        num_classes: Number of output logits.
        in_channels: Number of input channels (default 1).
        pretrained: Initialize the backbone from ImageNet weights.
    """

    def __init__(self, num_classes=3, in_channels=1, pretrained=True):
        super(FlowResNet, self).__init__()
        # Load a ResNet-18 backbone. torchvision >= 0.13 deprecates `pretrained=`
        # in favour of `weights=`; IMAGENET1K_V1 is what `pretrained=True` meant.
        if hasattr(models, "ResNet18_Weights"):
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            self.base_model = models.resnet18(weights=weights)
        else:
            self.base_model = models.resnet18(pretrained=pretrained)

        # Replace the first convolution so it accepts `in_channels` input channels.
        old_conv = self.base_model.conv1
        new_conv = nn.Conv2d(in_channels, old_conv.out_channels,
                             kernel_size=old_conv.kernel_size, stride=old_conv.stride,
                             padding=old_conv.padding, bias=False)
        if pretrained:
            # Reuse the pretrained RGB filters instead of discarding them: average over
            # the colour channels and rescale. For an input whose channels are identical,
            # this reproduces the pretrained conv1 response to that image replicated
            # across R, G and B.
            with torch.no_grad():
                mean_kernel = old_conv.weight.mean(dim=1, keepdim=True)
                scale = old_conv.in_channels / in_channels
                new_conv.weight.copy_(mean_kernel.repeat(1, in_channels, 1, 1) * scale)
        self.base_model.conv1 = new_conv

        # Replace the final fully connected layer with a `num_classes`-way head.
        in_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return class logits for a batch ``x`` of shape ``(N, in_channels, H, W)``."""
        return self.base_model(x)

# Seed the RNGs so the randomly initialised layers and the loaders' shuffling
# order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Initialize model, loss function and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = FlowResNet(num_classes=len(train_dataset.classes)).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the LR when validation loss plateaus for 5 epochs. The scheduler's
# `verbose=` flag is deprecated in recent PyTorch, so it is not passed here;
# the current LR can be read from `optimizer.param_groups[0]['lr']`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Training function
def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as floats.

    Trains (forward, backward, optimizer step) when ``optimizer`` is given,
    otherwise evaluates with gradients disabled. Averages are over the samples
    actually yielded, so custom samplers and ``drop_last`` are handled correctly.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()

    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    with torch.set_grad_enabled(training):
        for inputs, labels in tqdm(loader, desc=desc):
            # non_blocking=True only overlaps the copy with compute when the
            # loader returns pinned memory.
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_seen += batch_size

    if total_seen == 0:
        raise ValueError(f"{desc or 'Epoch'}: loader yielded no samples")
    return total_loss / total_seen, total_correct / total_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25,
                device=None, checkpoint_path='flow_classifier_best.pth'):
    """Train ``model`` and finish with the weights of the best validation-accuracy epoch.

    Args:
        model: Module to train (updated in place).
        train_loader, val_loader: Loaders yielding ``(inputs, labels)`` batches.
        criterion: Loss function taking ``(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Stepped once per epoch with the validation loss (as
            ``ReduceLROnPlateau`` expects); may be ``None``.
        num_epochs: Number of epochs to run.
        device: Device for the batches; defaults to the device of the model's
            first parameter.
        checkpoint_path: File the best weights are written to; ``None`` disables saving.

    Returns:
        ``(model, train_losses, val_losses, train_accs, val_accs)`` where the lists
        hold one float per epoch and ``model`` carries the best-epoch weights.
    """
    if device is None:
        device = next(model.parameters()).device

    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []

    best_val_acc = 0.0
    # In-memory snapshot of the best weights, so restoring never depends on a
    # missing or stale checkpoint file left over from a previous run.
    best_state = None

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device,
                                           optimizer=optimizer, desc="Training")
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')

        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device, desc="Validation")
        val_losses.append(val_loss)
        val_accs.append(val_acc)
        print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

        # Update scheduler on validation loss and report the resulting LR.
        if scheduler is not None:
            scheduler.step(val_loss)
        print(f"Learning rate: {optimizer.param_groups[0]['lr']:.2e}")

        # Keep the best model seen so far (the first epoch always qualifies).
        if best_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            if checkpoint_path is not None:
                torch.save(best_state, checkpoint_path)
                print(f'Saved new best model with validation accuracy: {best_val_acc:.4f}')
            else:
                print(f'New best model with validation accuracy: {best_val_acc:.4f}')

    # Restore the best weights (no-op when no epoch ran).
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, train_losses, val_losses, train_accs, val_accs

# Train the model. `.to(device)` is a no-op if the model is already there, but
# keeps this cell correct if `device` was changed in an earlier cell.
NUM_EPOCHS = 20
model = model.to(device)
model, train_losses, val_losses, train_accs, val_accs = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS)

# Plot training and validation metrics against 1-based epoch numbers, matching
# the "Epoch k/N" lines printed during training.
epochs = range(1, len(train_losses) + 1)
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs, train_losses, label='Training Loss')
loss_ax.plot(epochs, val_losses, label='Validation Loss')
loss_ax.set(xlabel='Epoch', ylabel='Loss', title='Loss Curves')
loss_ax.legend()

acc_ax.plot(epochs, train_accs, label='Training Accuracy')
acc_ax.plot(epochs, val_accs, label='Validation Accuracy')
acc_ax.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy Curves')
acc_ax.legend()

fig.tight_layout()
plt.show()

# Evaluate on validation set and create confusion matrix
def evaluate_model(model, data_loader, classes, device=None):
    """Plot a confusion matrix and print a classification report for ``data_loader``.

    Args:
        model: Trained classifier returning logits of shape ``(N, len(classes))``.
        data_loader: Loader yielding ``(inputs, labels)`` with integer labels
            indexing into ``classes``.
        classes: Class names, in label-index order.
        device: Device for the inputs; defaults to the device of the model's
            first parameter.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc="Evaluating"):
            inputs = inputs.to(device, non_blocking=True)
            outputs = model(inputs)
            # Labels are only consumed by sklearn on the CPU, so they never need
            # to visit the GPU.
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())

    # Pass the full label set so the matrix is always len(classes) x len(classes)
    # (aligned with the tick labels) and classification_report does not raise when
    # some class is absent from both the labels and the predictions.
    label_ids = list(range(len(classes)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
    plt.show()

    # Print classification report
    print("Classification Report:")
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=classes, zero_division=0))

# `.to(device)` is a no-op if the model is already on `device`.
model = model.to(device)
# Evaluate the model on the validation set.
evaluate_model(model, val_loader, train_dataset.classes)

Using device: cuda
Epoch 1/20
----------




Training:   0%|          | 0/938 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x75f6ecb02dd0>
Traceback (most recent call last):
  File "/home/ag42/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/ag42/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.is_alive():
  File "/usr/lib/python3.10/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x75f6ecb02dd0>
Traceback (most recent call last):
  File "/home/ag42/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/home/ag42/.local/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1601, in _shutdown_workers
    if w.i

KeyboardInterrupt: 

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x75f6f787ea10>>
Traceback (most recent call last):
  File "/home/ag42/.local/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
