In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:

import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import cv2
from PIL import Image

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Configuration
class Config:
    # Paths (adjust these to your Kaggle paths)
    DATA_DIR = '/kaggle/input/recursion-cellular-image-classification'
    TRAIN_CSV = f'{DATA_DIR}/train.csv'
    TEST_CSV = f'{DATA_DIR}/test.csv'
    
    # Model settings
    MODEL_NAME = 'efficientnet_b3'  # Fast and accurate
    IMG_SIZE = 320  # Reduced from 512 for speed
    BATCH_SIZE = 32  # Adjust based on GPU memory
    EPOCHS = 40  # Reduced for time
    LR = 3e-4
    
    # Training settings
    NUM_WORKERS = 2
    SEED = 42
    NUM_CLASSES = 1108
    
    # Use only HUVEC cell type for speed (you can add more if time permits)
    CELL_TYPES = ['HUVEC']  # Add 'RPE', 'HEPG2', 'U2OS' if you have time
    
    # sirna needs to be converted to numeric labels
    CONVERT_SIRNA = True



Using device: cuda


In [2]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(Config.SEED)

# Dataset class
class CellularDataset(Dataset):
    def __init__(self, df, data_dir, mode='train', transform=None):
        self.df = df
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def load_image(self, row):
        """Load 6-channel image"""
        if self.mode == 'train':
            exp = row['experiment']
            plate = row['plate']
            well = row['well']
            site = row['site']
            path_template = f'{self.data_dir}/train/{exp}/Plate{plate}/{well}_s{site}_w'
        else:
            img_id = row['id_code']
            exp = row['experiment']
            plate = row['plate']
            well = row['well']
            site = 1  # Test images are site 1
            path_template = f'{self.data_dir}/test/{exp}/Plate{plate}/{well}_s{site}_w'
        
        # Load all 6 channels
        channels = []
        for i in range(1, 7):
            img_path = f'{path_template}{i}.png'
            if os.path.exists(img_path):
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                channels.append(img)
            else:
                # Fallback if file doesn't exist
                channels.append(np.zeros((512, 512), dtype=np.uint8))
        
        # Stack channels and resize
        img = np.stack(channels, axis=-1)
        img = cv2.resize(img, (Config.IMG_SIZE, Config.IMG_SIZE))
        
        return img
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self.load_image(row)
        
        # Normalize to [0, 1]
        img = img.astype(np.float32) / 255.0
        
        if self.transform:
            # Convert to PIL for transforms (handle 6 channels)
            img = torch.from_numpy(img).permute(2, 0, 1)  # C, H, W
        else:
            img = torch.from_numpy(img).permute(2, 0, 1)
        
        if self.mode == 'train':
            label = row['label']
            return img, label
        else:
            return img, row['id_code']


In [3]:
class CellularModel(nn.Module):
    def __init__(self, model_name, num_classes, in_channels=6):
        super().__init__()
        # Load pretrained model
        self.backbone = timm.create_model(model_name, pretrained=True, in_chans=3)
        
        # Modify first conv layer to accept 6 channels
        if hasattr(self.backbone, 'conv_stem'):
            old_conv = self.backbone.conv_stem
            self.backbone.conv_stem = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            # Initialize with average of pretrained weights
            with torch.no_grad():
                self.backbone.conv_stem.weight[:, :3] = old_conv.weight
                self.backbone.conv_stem.weight[:, 3:] = old_conv.weight
        
        # Get number of features
        n_features = self.backbone.get_classifier().in_features
        self.backbone.reset_classifier(0)  # Remove classifier
        
        # Custom classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(n_features, num_classes)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

In [4]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        pbar.set_postfix({'loss': running_loss/len(loader), 'acc': 100.*correct/total})
    
    return running_loss/len(loader), 100.*correct/total

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc='Validation'):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return running_loss/len(loader), 100.*correct/total


In [5]:
def main():
    # Load data
    print("Loading data...")
    train_df = pd.read_csv(Config.TRAIN_CSV)
    test_df = pd.read_csv(Config.TEST_CSV)
    
    # Extract cell type from experiment column
    train_df['cell_type'] = train_df['experiment'].str.split('-').str[0]
    test_df['cell_type'] = test_df['experiment'].str.split('-').str[0]
    
    # Filter by cell type for speed
    train_df = train_df[train_df['cell_type'].isin(Config.CELL_TYPES)].reset_index(drop=True)
    test_df = test_df[test_df['cell_type'].isin(Config.CELL_TYPES)].reset_index(drop=True)
    
    print(f"Training samples: {len(train_df)}")
    print(f"Test samples: {len(test_df)}")
    print(f"Cell types in train: {train_df['cell_type'].unique()}")
    
    # Convert sirna labels to numeric (sirna_1 -> 1, sirna_10 -> 10, etc.)
    train_df['sirna_id'] = train_df['sirna'].str.replace('sirna_', '').astype(int)
    
    # Create label mapping (need to map to 0-indexed consecutive integers)
    unique_sirnas = sorted(train_df['sirna_id'].unique())
    sirna_to_label = {sirna: idx for idx, sirna in enumerate(unique_sirnas)}
    train_df['label'] = train_df['sirna_id'].map(sirna_to_label)
    
    print(f"Number of unique sirnas: {len(unique_sirnas)}")
    print(f"Label range: 0 to {train_df['label'].max()}")
    
    # Update NUM_CLASSES based on actual data
    Config.NUM_CLASSES = len(unique_sirnas)
    
    # Split train/val
    train_data, val_data = train_test_split(train_df, test_size=0.15, random_state=Config.SEED, stratify=train_df['label'])
    
    # Create datasets
    train_dataset = CellularDataset(train_data, Config.DATA_DIR, mode='train')
    val_dataset = CellularDataset(val_data, Config.DATA_DIR, mode='train')
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, 
                             shuffle=True, num_workers=Config.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, 
                           shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)
    
    # Create model
    print("Creating model...")
    model = CellularModel(Config.MODEL_NAME, Config.NUM_CLASSES).to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS)
    
    # Training loop
    best_acc = 0
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Saved best model with accuracy: {best_acc:.2f}%")
    
    # Load best model for inference
    print("\nLoading best model for inference...")
    model.load_state_dict(torch.load('best_model.pth'))
    
    # Inference on test set
    print("Generating predictions...")
    test_dataset = CellularDataset(test_df, Config.DATA_DIR, mode='test')
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, 
                            shuffle=False, num_workers=Config.NUM_WORKERS)
    
    model.eval()
    predictions = []
    ids = []
    
    with torch.no_grad():
        for imgs, img_ids in tqdm(test_loader, desc='Inference'):
            imgs = imgs.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            
            predictions.extend(preds.cpu().numpy())
            ids.extend(img_ids)
    
    # Convert predictions back to sirna format
    label_to_sirna = {idx: sirna for sirna, idx in sirna_to_label.items()}
    predictions_sirna = [label_to_sirna[pred] for pred in predictions]
    
    # Create submission
    submission = pd.DataFrame({
        'id_code': ids,
        'sirna': predictions_sirna
    })
    submission.to_csv('submission.csv', index=False)
    print("\nSubmission saved to submission.csv")
    print(f"Best validation accuracy: {best_acc:.2f}%")
    print(f"Sample predictions:")
    print(submission.head(10))

if __name__ == '__main__':
    main()

Loading data...
Training samples: 17689
Test samples: 8847
Cell types in train: ['HUVEC']
Number of unique sirnas: 1108
Label range: 0 to 1107
Creating model...


model.safetensors:   0%|          | 0.00/49.3M [00:00<?, ?B/s]


Epoch 1/40


Training:   0%|          | 0/470 [00:00<?, ?it/s]


KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/pandas/core/indexes/base.py", line 3805, in get_loc
    return self._engine.get_loc(casted_key)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "index.pyx", line 167, in pandas._libs.index.IndexEngine.get_loc
  File "index.pyx", line 196, in pandas._libs.index.IndexEngine.get_loc
  File "pandas/_libs/hashtable_class_helper.pxi", line 7081, in pandas._libs.hashtable.PyObjectHashTable.get_item
  File "pandas/_libs/hashtable_class_helper.pxi", line 7089, in pandas._libs.hashtable.PyObjectHashTable.get_item
KeyError: 'site'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
            ~~~~~~~~~~~~^^^^^
  File "/tmp/ipykernel_48/1543366408.py", line 55, in __getitem__
    img = self.load_image(row)
          ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_48/1543366408.py", line 26, in load_image
    site = row['site']
           ~~~^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pandas/core/series.py", line 1121, in __getitem__
    return self._get_value(key)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pandas/core/series.py", line 1237, in _get_value
    loc = self.index.get_loc(label)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pandas/core/indexes/base.py", line 3812, in get_loc
    raise KeyError(key) from err
KeyError: 'site'


In [7]:
import pandas as pd

# Load the CSVs
DATA_DIR = '/kaggle/input/recursion-cellular-image-classification'
train_df = pd.read_csv(f'{DATA_DIR}/train.csv')
test_df = pd.read_csv(f'{DATA_DIR}/test.csv')

# Print info about the dataframes
print("TRAIN CSV INFO:")
print(train_df.head())
print("\nTrain columns:", train_df.columns.tolist())
print("Train shape:", train_df.shape)

print("\n" + "="*50)
print("\nTEST CSV INFO:")
print(test_df.head())
print("\nTest columns:", test_df.columns.tolist())
print("Test shape:", test_df.shape)

# Check unique values for some columns
if 'experiment' in train_df.columns:
    print("\nUnique experiments:", train_df['experiment'].unique())
if 'plate' in train_df.columns:
    print("Unique plates:", train_df['plate'].nunique())

TRAIN CSV INFO:
          id_code experiment  plate well       sirna
0  HEPG2-01_1_B03   HEPG2-01      1  B03   sirna_250
1  HEPG2-01_1_B04   HEPG2-01      1  B04    sirna_62
2  HEPG2-01_1_B05   HEPG2-01      1  B05  sirna_1115
3  HEPG2-01_1_B06   HEPG2-01      1  B06   sirna_602
4  HEPG2-01_1_B07   HEPG2-01      1  B07   sirna_529

Train columns: ['id_code', 'experiment', 'plate', 'well', 'sirna']
Train shape: (36517, 5)


TEST CSV INFO:
          id_code experiment  plate well
0  HEPG2-08_1_B03   HEPG2-08      1  B03
1  HEPG2-08_1_B04   HEPG2-08      1  B04
2  HEPG2-08_1_B05   HEPG2-08      1  B05
3  HEPG2-08_1_B06   HEPG2-08      1  B06
4  HEPG2-08_1_B07   HEPG2-08      1  B07

Test columns: ['id_code', 'experiment', 'plate', 'well']
Test shape: (19899, 4)

Unique experiments: ['HEPG2-01' 'HEPG2-02' 'HEPG2-03' 'HEPG2-04' 'HEPG2-05' 'HEPG2-06'
 'HEPG2-07' 'HUVEC-01' 'HUVEC-02' 'HUVEC-03' 'HUVEC-04' 'HUVEC-05'
 'HUVEC-06' 'HUVEC-07' 'HUVEC-08' 'HUVEC-09' 'HUVEC-10' 'HUVEC-11'
 'HUVEC-

In [8]:
import os
import glob

# Check one example to see the file structure
sample = train_df.iloc[0]
exp = sample['experiment']
plate = sample['plate']
well = sample['well']

path = f'/kaggle/input/recursion-cellular-image-classification/train/{exp}/Plate{plate}/'
files = sorted(glob.glob(f'{path}{well}*.png'))
print(f"Files for {well}:")
for f in files[:10]:  # Show first 10
    print(f)

Files for B03:
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w1.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w2.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w3.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w4.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w5.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s1_w6.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s2_w1.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s2_w2.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s2_w3.png
/kaggle/input/recursion-cellular-image-classification/train/HEPG2-01/Plate1/B03_s2_w4.png


In [9]:
# Recursion Cellular Image Classification - Quick Solution
# Optimized for time constraints (< 10 hours)

import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import cv2
from PIL import Image

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Configuration
class Config:
    # Paths (adjust these to your Kaggle paths)
    DATA_DIR = '/kaggle/input/recursion-cellular-image-classification'
    TRAIN_CSV = f'{DATA_DIR}/train.csv'
    TEST_CSV = f'{DATA_DIR}/test.csv'
    
    # Model settings
    MODEL_NAME = 'efficientnet_b3'  # Fast and accurate
    IMG_SIZE = 320  # Reduced from 512 for speed
    BATCH_SIZE = 32  # Adjust based on GPU memory
    EPOCHS = 40  # Reduced for time
    LR = 3e-4
    
    # Training settings
    NUM_WORKERS = 2
    SEED = 42
    NUM_CLASSES = 1108
    
    # Use only HUVEC cell type for speed (you can add more if time permits)
    CELL_TYPES = ['HUVEC']  # Add 'RPE', 'HEPG2', 'U2OS' if you have time
    
    # sirna needs to be converted to numeric labels
    CONVERT_SIRNA = True

# Set seed
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

set_seed(Config.SEED)

# Dataset class
class CellularDataset(Dataset):
    def __init__(self, df, data_dir, mode='train', transform=None):
        self.df = df
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def load_image(self, row):
        """Load 6-channel image"""
        exp = row['experiment']
        plate = row['plate']
        well = row['well']
        
        # Extract site from id_code (format: CELLTYPE-XX_PLATE_WELL_siteN)
        # For train: HEPG2-01_1_B03 -> need to find the image files
        # The site information is in the actual filename, not the CSV
        # We need to try site 1 and site 2
        
        if self.mode == 'train':
            path_template = f'{self.data_dir}/train/{exp}/Plate{plate}/{well}_s1_w'
        else:
            path_template = f'{self.data_dir}/test/{exp}/Plate{plate}/{well}_s1_w'
        
        # Load all 6 channels
        channels = []
        for i in range(1, 7):
            img_path = f'{path_template}{i}.png'
            if os.path.exists(img_path):
                img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
                channels.append(img)
            else:
                # Fallback if file doesn't exist
                channels.append(np.zeros((512, 512), dtype=np.uint8))
        
        # Stack channels and resize
        img = np.stack(channels, axis=-1)
        img = cv2.resize(img, (Config.IMG_SIZE, Config.IMG_SIZE))
        
        return img
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self.load_image(row)
        
        # Normalize to [0, 1]
        img = img.astype(np.float32) / 255.0
        
        if self.transform:
            # Convert to PIL for transforms (handle 6 channels)
            img = torch.from_numpy(img).permute(2, 0, 1)  # C, H, W
        else:
            img = torch.from_numpy(img).permute(2, 0, 1)
        
        if self.mode == 'train':
            label = row['label']
            return img, label
        else:
            return img, row['id_code']

# Model
class CellularModel(nn.Module):
    def __init__(self, model_name, num_classes, in_channels=6):
        super().__init__()
        # Load pretrained model
        self.backbone = timm.create_model(model_name, pretrained=True, in_chans=3)
        
        # Modify first conv layer to accept 6 channels
        if hasattr(self.backbone, 'conv_stem'):
            old_conv = self.backbone.conv_stem
            self.backbone.conv_stem = nn.Conv2d(
                in_channels, old_conv.out_channels,
                kernel_size=old_conv.kernel_size,
                stride=old_conv.stride,
                padding=old_conv.padding,
                bias=False
            )
            # Initialize with average of pretrained weights
            with torch.no_grad():
                self.backbone.conv_stem.weight[:, :3] = old_conv.weight
                self.backbone.conv_stem.weight[:, 3:] = old_conv.weight
        
        # Get number of features
        n_features = self.backbone.get_classifier().in_features
        self.backbone.reset_classifier(0)  # Remove classifier
        
        # Custom classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(n_features, num_classes)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        return self.classifier(features)

# Training function
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    pbar = tqdm(loader, desc='Training')
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        pbar.set_postfix({'loss': running_loss/len(loader), 'acc': 100.*correct/total})
    
    return running_loss/len(loader), 100.*correct/total

# Validation function
def validate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc='Validation'):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return running_loss/len(loader), 100.*correct/total

# Main training pipeline
def main():
    # Load data
    print("Loading data...")
    train_df = pd.read_csv(Config.TRAIN_CSV)
    test_df = pd.read_csv(Config.TEST_CSV)
    
    # Extract cell type from experiment column
    train_df['cell_type'] = train_df['experiment'].str.split('-').str[0]
    test_df['cell_type'] = test_df['experiment'].str.split('-').str[0]
    
    # Filter by cell type for speed
    train_df = train_df[train_df['cell_type'].isin(Config.CELL_TYPES)].reset_index(drop=True)
    test_df = test_df[test_df['cell_type'].isin(Config.CELL_TYPES)].reset_index(drop=True)
    
    print(f"Training samples: {len(train_df)}")
    print(f"Test samples: {len(test_df)}")
    print(f"Cell types in train: {train_df['cell_type'].unique()}")
    
    # Convert sirna labels to numeric (sirna_1 -> 1, sirna_10 -> 10, etc.)
    train_df['sirna_id'] = train_df['sirna'].str.replace('sirna_', '').astype(int)
    
    # Create label mapping (need to map to 0-indexed consecutive integers)
    unique_sirnas = sorted(train_df['sirna_id'].unique())
    sirna_to_label = {sirna: idx for idx, sirna in enumerate(unique_sirnas)}
    train_df['label'] = train_df['sirna_id'].map(sirna_to_label)
    
    print(f"Number of unique sirnas: {len(unique_sirnas)}")
    print(f"Label range: 0 to {train_df['label'].max()}")
    
    # Update NUM_CLASSES based on actual data
    Config.NUM_CLASSES = len(unique_sirnas)
    
    # Split train/val
    train_data, val_data = train_test_split(train_df, test_size=0.15, random_state=Config.SEED, stratify=train_df['label'])
    
    # Create datasets
    train_dataset = CellularDataset(train_data, Config.DATA_DIR, mode='train')
    val_dataset = CellularDataset(val_data, Config.DATA_DIR, mode='train')
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, 
                             shuffle=True, num_workers=Config.NUM_WORKERS, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, 
                           shuffle=False, num_workers=Config.NUM_WORKERS, pin_memory=True)
    
    # Create model
    print("Creating model...")
    model = CellularModel(Config.MODEL_NAME, Config.NUM_CLASSES).to(device)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=Config.LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Config.EPOCHS)
    
    # Training loop
    best_acc = 0
    for epoch in range(Config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{Config.EPOCHS}")
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()
        
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        
        # Save best model
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"Saved best model with accuracy: {best_acc:.2f}%")
    
    # Load best model for inference
    print("\nLoading best model for inference...")
    model.load_state_dict(torch.load('best_model.pth'))
    
    # Inference on test set
    print("Generating predictions...")
    test_dataset = CellularDataset(test_df, Config.DATA_DIR, mode='test')
    test_loader = DataLoader(test_dataset, batch_size=Config.BATCH_SIZE, 
                            shuffle=False, num_workers=Config.NUM_WORKERS)
    
    model.eval()
    predictions = []
    ids = []
    
    with torch.no_grad():
        for imgs, img_ids in tqdm(test_loader, desc='Inference'):
            imgs = imgs.to(device)
            outputs = model(imgs)
            _, preds = outputs.max(1)
            
            predictions.extend(preds.cpu().numpy())
            ids.extend(img_ids)
    
    # Convert predictions back to sirna format
    label_to_sirna = {idx: sirna for sirna, idx in sirna_to_label.items()}
    predictions_sirna = [label_to_sirna[pred] for pred in predictions]
    
    # Create submission
    submission = pd.DataFrame({
        'id_code': ids,
        'sirna': predictions_sirna
    })
    submission.to_csv('submission.csv', index=False)
    print("\nSubmission saved to submission.csv")
    print(f"Best validation accuracy: {best_acc:.2f}%")
    print(f"Sample predictions:")
    print(submission.head(10))

if __name__ == '__main__':
    main()

Using device: cuda
Loading data...
Training samples: 17689
Test samples: 8847
Cell types in train: ['HUVEC']
Number of unique sirnas: 1108
Label range: 0 to 1107
Creating model...

Epoch 1/40


Training: 100%|██████████| 470/470 [12:53<00:00,  1.65s/it, loss=6.4, acc=1.84]   
Validation: 100%|██████████| 83/83 [02:18<00:00,  1.67s/it]


Train Loss: 6.4036, Train Acc: 1.84%
Val Loss: 5.1270, Val Acc: 8.21%
Saved best model with accuracy: 8.21%

Epoch 2/40


Training: 100%|██████████| 470/470 [06:11<00:00,  1.27it/s, loss=4.32, acc=14.9]
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.31it/s]


Train Loss: 4.3219, Train Acc: 14.91%
Val Loss: 3.5653, Val Acc: 25.28%
Saved best model with accuracy: 25.28%

Epoch 3/40


Training: 100%|██████████| 470/470 [05:58<00:00,  1.31it/s, loss=2.75, acc=39.1] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.33it/s]


Train Loss: 2.7543, Train Acc: 39.08%
Val Loss: 2.9472, Val Acc: 36.10%
Saved best model with accuracy: 36.10%

Epoch 4/40


Training: 100%|██████████| 470/470 [06:07<00:00,  1.28it/s, loss=1.66, acc=63.4] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Train Loss: 1.6623, Train Acc: 63.39%
Val Loss: 2.7081, Val Acc: 41.18%
Saved best model with accuracy: 41.18%

Epoch 5/40


Training: 100%|██████████| 470/470 [06:01<00:00,  1.30it/s, loss=0.955, acc=79.9]
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.35it/s]


Train Loss: 0.9552, Train Acc: 79.93%
Val Loss: 2.6540, Val Acc: 44.88%
Saved best model with accuracy: 44.88%

Epoch 6/40


Training: 100%|██████████| 470/470 [05:59<00:00,  1.31it/s, loss=0.531, acc=89.5]
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Train Loss: 0.5310, Train Acc: 89.52%
Val Loss: 2.7588, Val Acc: 43.78%

Epoch 7/40


Training: 100%|██████████| 470/470 [06:17<00:00,  1.24it/s, loss=0.274, acc=95]   
Validation: 100%|██████████| 83/83 [01:05<00:00,  1.27it/s]


Train Loss: 0.2736, Train Acc: 95.00%
Val Loss: 2.8202, Val Acc: 45.74%
Saved best model with accuracy: 45.74%

Epoch 8/40


Training: 100%|██████████| 470/470 [06:08<00:00,  1.27it/s, loss=0.156, acc=97.6] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.34it/s]


Train Loss: 0.1560, Train Acc: 97.61%
Val Loss: 2.8313, Val Acc: 45.21%

Epoch 9/40


Training: 100%|██████████| 470/470 [06:11<00:00,  1.27it/s, loss=0.0988, acc=98.5]
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.36it/s]


Train Loss: 0.0988, Train Acc: 98.45%
Val Loss: 2.9313, Val Acc: 45.59%

Epoch 10/40


Training: 100%|██████████| 470/470 [06:07<00:00,  1.28it/s, loss=0.0692, acc=99.1]
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.35it/s]


Train Loss: 0.0692, Train Acc: 99.12%
Val Loss: 3.0527, Val Acc: 45.74%

Epoch 11/40


Training: 100%|██████████| 470/470 [06:04<00:00,  1.29it/s, loss=0.0718, acc=98.9]
Validation: 100%|██████████| 83/83 [01:00<00:00,  1.36it/s]


Train Loss: 0.0718, Train Acc: 98.86%
Val Loss: 2.9944, Val Acc: 44.50%

Epoch 12/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.29it/s, loss=0.07, acc=98.8]  
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.34it/s]


Train Loss: 0.0700, Train Acc: 98.84%
Val Loss: 3.1608, Val Acc: 44.39%

Epoch 13/40


Training: 100%|██████████| 470/470 [06:10<00:00,  1.27it/s, loss=0.072, acc=98.6] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Train Loss: 0.0720, Train Acc: 98.60%
Val Loss: 3.1729, Val Acc: 44.80%

Epoch 14/40


Training: 100%|██████████| 470/470 [06:12<00:00,  1.26it/s, loss=0.0539, acc=99.2]
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.31it/s]


Train Loss: 0.0539, Train Acc: 99.18%
Val Loss: 3.2002, Val Acc: 45.93%
Saved best model with accuracy: 45.93%

Epoch 15/40


Training: 100%|██████████| 470/470 [06:04<00:00,  1.29it/s, loss=0.05, acc=99.1]   
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.35it/s]


Train Loss: 0.0500, Train Acc: 99.08%
Val Loss: 3.3098, Val Acc: 44.80%

Epoch 16/40


Training: 100%|██████████| 470/470 [05:55<00:00,  1.32it/s, loss=0.037, acc=99.3]  
Validation: 100%|██████████| 83/83 [01:00<00:00,  1.36it/s]


Train Loss: 0.0370, Train Acc: 99.32%
Val Loss: 3.2328, Val Acc: 44.84%

Epoch 17/40


Training: 100%|██████████| 470/470 [05:58<00:00,  1.31it/s, loss=0.0266, acc=99.6] 
Validation: 100%|██████████| 83/83 [00:59<00:00,  1.41it/s]


Train Loss: 0.0266, Train Acc: 99.63%
Val Loss: 3.2797, Val Acc: 46.42%
Saved best model with accuracy: 46.42%

Epoch 18/40


Training: 100%|██████████| 470/470 [06:01<00:00,  1.30it/s, loss=0.0245, acc=99.6] 
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.34it/s]


Train Loss: 0.0245, Train Acc: 99.63%
Val Loss: 3.3621, Val Acc: 45.74%

Epoch 19/40


Training: 100%|██████████| 470/470 [06:15<00:00,  1.25it/s, loss=0.0215, acc=99.7] 
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.34it/s]


Train Loss: 0.0215, Train Acc: 99.67%
Val Loss: 3.3433, Val Acc: 45.21%

Epoch 20/40


Training: 100%|██████████| 470/470 [06:01<00:00,  1.30it/s, loss=0.0214, acc=99.6] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.33it/s]


Train Loss: 0.0214, Train Acc: 99.64%
Val Loss: 3.4157, Val Acc: 46.31%

Epoch 21/40


Training: 100%|██████████| 470/470 [05:48<00:00,  1.35it/s, loss=0.0192, acc=99.7] 
Validation: 100%|██████████| 83/83 [01:00<00:00,  1.38it/s]


Train Loss: 0.0192, Train Acc: 99.67%
Val Loss: 3.2663, Val Acc: 47.10%
Saved best model with accuracy: 47.10%

Epoch 22/40


Training: 100%|██████████| 470/470 [05:53<00:00,  1.33it/s, loss=0.012, acc=99.9]  
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.36it/s]


Train Loss: 0.0120, Train Acc: 99.87%
Val Loss: 3.2943, Val Acc: 47.17%
Saved best model with accuracy: 47.17%

Epoch 23/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.28it/s, loss=0.00825, acc=99.9]
Validation: 100%|██████████| 83/83 [01:04<00:00,  1.29it/s]


Train Loss: 0.0083, Train Acc: 99.90%
Val Loss: 3.2844, Val Acc: 47.81%
Saved best model with accuracy: 47.81%

Epoch 24/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.29it/s, loss=0.00673, acc=99.9]
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.35it/s]


Train Loss: 0.0067, Train Acc: 99.93%
Val Loss: 3.2730, Val Acc: 46.76%

Epoch 25/40


Training: 100%|██████████| 470/470 [06:08<00:00,  1.28it/s, loss=0.00439, acc=100] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Train Loss: 0.0044, Train Acc: 99.95%
Val Loss: 3.2601, Val Acc: 47.89%
Saved best model with accuracy: 47.89%

Epoch 26/40


Training: 100%|██████████| 470/470 [06:13<00:00,  1.26it/s, loss=0.00419, acc=99.9]
Validation: 100%|██████████| 83/83 [01:05<00:00,  1.27it/s]


Train Loss: 0.0042, Train Acc: 99.94%
Val Loss: 3.2855, Val Acc: 47.93%
Saved best model with accuracy: 47.93%

Epoch 27/40


Training: 100%|██████████| 470/470 [06:08<00:00,  1.28it/s, loss=0.00407, acc=100] 
Validation: 100%|██████████| 83/83 [01:07<00:00,  1.23it/s]


Train Loss: 0.0041, Train Acc: 99.96%
Val Loss: 3.3401, Val Acc: 47.78%

Epoch 28/40


Training: 100%|██████████| 470/470 [06:06<00:00,  1.28it/s, loss=0.00259, acc=100] 
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.32it/s]


Train Loss: 0.0026, Train Acc: 99.99%
Val Loss: 3.2853, Val Acc: 47.70%

Epoch 29/40


Training: 100%|██████████| 470/470 [05:55<00:00,  1.32it/s, loss=0.00254, acc=100] 
Validation: 100%|██████████| 83/83 [01:01<00:00,  1.36it/s]


Train Loss: 0.0025, Train Acc: 99.98%
Val Loss: 3.3068, Val Acc: 47.59%

Epoch 30/40


Training: 100%|██████████| 470/470 [05:51<00:00,  1.34it/s, loss=0.00204, acc=100] 
Validation: 100%|██████████| 83/83 [00:59<00:00,  1.39it/s]


Train Loss: 0.0020, Train Acc: 99.99%
Val Loss: 3.3277, Val Acc: 48.46%
Saved best model with accuracy: 48.46%

Epoch 31/40


Training: 100%|██████████| 470/470 [05:51<00:00,  1.34it/s, loss=0.00196, acc=100] 
Validation: 100%|██████████| 83/83 [01:00<00:00,  1.36it/s]


Train Loss: 0.0020, Train Acc: 99.98%
Val Loss: 3.2972, Val Acc: 47.59%

Epoch 32/40


Training: 100%|██████████| 470/470 [05:53<00:00,  1.33it/s, loss=0.00151, acc=100]]
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.30it/s]


Train Loss: 0.0015, Train Acc: 100.00%
Val Loss: 3.2630, Val Acc: 48.38%

Epoch 33/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.29it/s, loss=0.00106, acc=100] 
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.30it/s]


Train Loss: 0.0011, Train Acc: 100.00%
Val Loss: 3.2688, Val Acc: 48.15%

Epoch 34/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.28it/s, loss=0.00091, acc=100] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.33it/s]


Train Loss: 0.0009, Train Acc: 100.00%
Val Loss: 3.2603, Val Acc: 48.57%
Saved best model with accuracy: 48.57%

Epoch 35/40


Validation: 100%|██████████| 83/83 [01:01<00:00,  1.35it/s] loss=0.000663, acc=100]


Train Loss: 0.0008, Train Acc: 100.00%
Val Loss: 3.2631, Val Acc: 49.47%
Saved best model with accuracy: 49.47%

Epoch 36/40


Training: 100%|██████████| 470/470 [06:12<00:00,  1.26it/s, loss=0.00104, acc=100] 
Validation: 100%|██████████| 83/83 [01:02<00:00,  1.32it/s]


Train Loss: 0.0010, Train Acc: 99.99%
Val Loss: 3.2373, Val Acc: 48.83%

Epoch 37/40


Training: 100%|██████████| 470/470 [06:07<00:00,  1.28it/s, loss=0.000796, acc=100]
Validation: 100%|██████████| 83/83 [01:05<00:00,  1.27it/s]


Train Loss: 0.0008, Train Acc: 100.00%
Val Loss: 3.2375, Val Acc: 49.17%

Epoch 38/40


Training: 100%|██████████| 470/470 [06:14<00:00,  1.25it/s, loss=0.000789, acc=100]
Validation: 100%|██████████| 83/83 [01:04<00:00,  1.28it/s]


Train Loss: 0.0008, Train Acc: 99.99%
Val Loss: 3.2795, Val Acc: 49.25%

Epoch 39/40


Training: 100%|██████████| 470/470 [06:05<00:00,  1.28it/s, loss=0.000795, acc=100]
Validation: 100%|██████████| 83/83 [01:03<00:00,  1.31it/s]


Train Loss: 0.0008, Train Acc: 100.00%
Val Loss: 3.2378, Val Acc: 49.17%

Epoch 40/40


Training: 100%|██████████| 470/470 [06:13<00:00,  1.26it/s, loss=0.000763, acc=100]
Validation: 100%|██████████| 83/83 [01:06<00:00,  1.24it/s]


Train Loss: 0.0008, Train Acc: 100.00%
Val Loss: 3.2467, Val Acc: 49.21%

Loading best model for inference...
Generating predictions...


Inference: 100%|██████████| 277/277 [08:26<00:00,  1.83s/it]


Submission saved to submission.csv
Best validation accuracy: 49.47%
Sample predictions:
          id_code  sirna
0  HUVEC-17_1_B03    671
1  HUVEC-17_1_B04    225
2  HUVEC-17_1_B05     54
3  HUVEC-17_1_B06    276
4  HUVEC-17_1_B07    564
5  HUVEC-17_1_B08     72
6  HUVEC-17_1_B09    767
7  HUVEC-17_1_B10   1015
8  HUVEC-17_1_B11    292
9  HUVEC-17_1_B12   1078



