## Step 1: Import dependencies

In [1]:
import os
import math
import datetime
import random
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2
from tqdm.auto import tqdm
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from vit_pytorch.crossformer import CrossFormer
from torch.amp import autocast, GradScaler
from torchvision.transforms.v2.functional import InterpolationMode

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
num_workers = max((os.cpu_count() or 1) - 1, 0)
device, num_workers

(device(type='cuda'), 11)

## Step 2: Define dataset

In [2]:
class ParquetImageDataset(Dataset):
    """
    A PyTorch Dataset that reads a Parquet file containing:
      - img_path: path to image on disk
      - label (str): label of the image (e.g., 'spiral', 'smooth', etc.)
    """
    def __init__(self, parquet_file, transform=None, label_encoder=None):
        super().__init__()
        self.data = pd.read_parquet(parquet_file)
        self.transform = transform
        self.label_encoder = label_encoder

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_path = "dataset/"+row['img_path']
        label_str = row['label']

        image = Image.open(img_path).convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        if self.label_encoder is not None:
            label = self.label_encoder.transform([label_str])[0]
            label = torch.tensor(label, dtype=torch.long)
        else:
            label = label_str

        return image, label


In [3]:
TRAIN_DATASET_PATH = "dataset/train.parquet"
VALIDATION_DATASET_PATH = "dataset/validation.parquet"
TEST_DATASET_PATH = "dataset/test.parquet"

In [4]:
train_df = pd.read_parquet(TRAIN_DATASET_PATH)
val_df   = pd.read_parquet(VALIDATION_DATASET_PATH)
test_df  = pd.read_parquet(TEST_DATASET_PATH)

all_labels = pd.concat([train_df['label'], val_df['label'], test_df['label']])

le = LabelEncoder()
le.fit(all_labels)

print("Classes found:", le.classes_, len(le.classes_))

Classes found: ['barred_spiral' 'edge_on_disk' 'featured_without_bar_or_spiral'
 'irregular' 'smooth_cigar' 'smooth_inbetween' 'smooth_round'
 'unbarred_spiral'] 8


## Step 3: Define On-the-Fly Augmentation

In [5]:
train_transform = v2.Compose([
    # 1) Random crop/resize to simulate zoom & shifts
    v2.RandomResizedCrop(
        size=(224, 224),
        scale=(0.8, 1.0),   # random area: 80% ~ 100% of original
        ratio=(0.9, 1.1),  # aspect ratio range to allow slight stretching
        interpolation=InterpolationMode.BILINEAR
    ),
    
    # 2) Randomly rotate in full 360° range
    v2.RandomRotation(
        degrees=(-180, 180),
        interpolation=InterpolationMode.BILINEAR
    ),
    
    # 3) Random flips (helps orientation invariance)
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    
    # 4) Mild color jitter for brightness, contrast, saturation
    v2.ColorJitter(
        brightness=0.1,
        contrast=0.1,
        saturation=0.1,
        hue=0.0
    ),
    
    # 5) Randomly apply a slight Gaussian blur (~30% chance)
    v2.RandomApply(
        [v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
        p=0.3
    ),
    
    # 6) Convert to float tensor [0..1]
    v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])  # v2.ToTensor()

])

val_test_transform = v2.Compose([
    v2.Resize((224, 224)),
    v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)])  # v2.ToTensor()
])

## Step 4: Load Datasets

In [6]:
train_dataset = ParquetImageDataset(
    parquet_file=TRAIN_DATASET_PATH,
    transform=train_transform,
    label_encoder=le
)

validation_dataset = ParquetImageDataset(
    parquet_file=VALIDATION_DATASET_PATH,
    transform=val_test_transform,
    label_encoder=le
)

test_dataset = ParquetImageDataset(
    parquet_file=TEST_DATASET_PATH,
    transform=val_test_transform,
    label_encoder=le
)

print("Train train set size:", len(train_dataset))
print("Validation set size:", len(validation_dataset))
print("Test set size:", len(test_dataset))

Train train set size: 125575
Validation set size: 41859
Test set size: 41859


## 3. Class Weights + Sampler

In [7]:
class_counts = train_df['label'].value_counts()

class_counts

label
smooth_inbetween                  47179
smooth_round                      32944
unbarred_spiral                   15672
smooth_cigar                      11847
irregular                          7089
barred_spiral                      5483
edge_on_disk                       2774
featured_without_bar_or_spiral     2587
Name: count, dtype: int64

In [8]:
num_classes = len(le.classes_)
weight_array = np.zeros(num_classes, dtype=np.float32)

# Loop through each class label (as a string) in le.classes_, find how many samples it has, and compute 1 / frequency
for class_str in le.classes_:
    class_idx = le.transform([class_str])[0]
    count = class_counts[class_str]
    weight_array[class_idx] = 1.0 / count


In [9]:
class_weights_tensor = torch.tensor(weight_array, dtype=torch.float32).to(device)
class_weights_tensor

tensor([1.8238e-04, 3.6049e-04, 3.8655e-04, 1.4106e-04, 8.4410e-05, 2.1196e-05,
        3.0355e-05, 6.3808e-05], device='cuda:0')

In [10]:
batch_size = 64

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)

val_loader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

train_loader, val_loader, test_loader

(<torch.utils.data.dataloader.DataLoader at 0x7f5092864890>,
 <torch.utils.data.dataloader.DataLoader at 0x7f545ce0e540>,
 <torch.utils.data.dataloader.DataLoader at 0x7f519ac59a60>)

## Step 5: Initialize CrossFormer Model

In [11]:
model = CrossFormer(
    num_classes=len(le.classes_),
    dim=(32, 64, 128, 256),
    depth=(2, 2, 4, 2),
    global_window_size=(8, 4, 2, 1),
    local_window_size=7,
    attn_dropout=0.1,
    ff_dropout=0.1
).to(device)

## Step 7: Training Loop

In [12]:
### Load model from previous training:

checkpoint_path = "checkpoints/checkpoint_epoch_25_20250304_acc_0.7567.pth"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [13]:
os.makedirs("checkpoints", exist_ok=True) # directory for checkpoints store

try:
    best_val_acc = checkpoint['val_acc'] 
except:
    best_val_acc = 0.0

num_epochs = 20
warmup_epochs = round(num_epochs * 0.1)
base_lr = 2e-6
weight_decay = 1e-4

def lr_lambda(epoch):
    if epoch < warmup_epochs:
        return float(epoch + 1) / warmup_epochs
    else:
        progress = (epoch - warmup_epochs) / (num_epochs - warmup_epochs)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

scaler = GradScaler("cuda")
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
# optimizer = optim.AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
optimizer = optim.Adam(model.parameters(), lr=base_lr)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

print(f"Total number of epochs: {num_epochs}")
print(f"The number of warm up epochs: {warmup_epochs}")

Total number of epochs: 20
The number of warm up epochs: 2


In [14]:
for epoch in range(num_epochs):
    print(f"\nCurrent LR: {scheduler.get_last_lr()[0]:.2e}")
    
    ################################
    #          Training            #
    ################################
    model.train()
    running_train_loss = 0.0
    
    train_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", smoothing=0.9)
    
    for images, labels in train_tqdm:
        images, labels = images.to(device), labels.to(device)

        # Zero out the gradients
        optimizer.zero_grad()
        
        # 2) Mixed precision forward pass
        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = criterion(outputs, labels)
        
        # 3) Backprop with scaled loss
        scaler.scale(loss).backward()
        
        # 4) Gradient clipping AFTER unscaling
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # 5) Step the optimizer with scaled gradients
        scaler.step(optimizer)
        scaler.update()
        
        running_train_loss += loss.item()
        train_tqdm.set_postfix(loss=loss.item())
        
    train_loss = running_train_loss / len(train_loader)

    ################################
    #         Validation           #
    ################################
    model.eval()
    running_val_loss = 0.0
    correct, total = 0, 0

    val_tqdm = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", smoothing=0.9)
    with torch.no_grad():
        for images, labels in val_tqdm:
            images, labels = images.to(device), labels.to(device)
            
            # Mixed precision inference also possible (faster on GPUs)
            with autocast(device_type='cuda'):
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            running_val_loss += loss.item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total   += labels.size(0)

            val_tqdm.set_postfix(loss=loss.item())

    val_loss = running_val_loss / len(val_loader)
    val_acc = correct / total

    print(f"Epoch [{epoch+1}/{num_epochs}] Summary: "
          f"Train Loss={train_loss:.4f}  |  Val Loss={val_loss:.4f}  |  Val Acc={val_acc:.4f}")
    
    # Checkpoint saving logic
    if ((epoch/num_epochs) <= 0.50 and (val_acc - best_val_acc) > 0.01) \
       or ((epoch/num_epochs) < 0.75 and (val_acc - best_val_acc) > 0.005) \
       or ((epoch/num_epochs) >= 0.75 and val_acc > best_val_acc):
        best_val_acc = val_acc
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_acc': val_acc
        }
        checkpoint_filename = f"checkpoints/{val_acc:.4f}_epoch_{epoch+1}_{datetime.datetime.today().strftime('%Y%m%d')}.pth"
        torch.save(checkpoint, checkpoint_filename)
        print(f"New best model saved: {checkpoint_filename}")
    
    scheduler.step()



Current LR: 1.00e-06


Epoch 1/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

  return _VF.einsum(equation, operands)  # type: ignore[attr-defined]


Epoch 1/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [1/20] Summary: Train Loss=0.8016  |  Val Loss=0.8158  |  Val Acc=0.7528

Current LR: 2.00e-06


Epoch 2/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 2/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [2/20] Summary: Train Loss=0.7954  |  Val Loss=0.8184  |  Val Acc=0.7511

Current LR: 2.00e-06


Epoch 3/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 3/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [3/20] Summary: Train Loss=0.7972  |  Val Loss=0.8119  |  Val Acc=0.7523

Current LR: 1.98e-06


Epoch 4/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 4/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060><function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060>
Exception ignored in: 
Traceback (most recent call last):
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060>Traceback (most recent call last):
Exception ignored in:   File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060>  File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__


    Traceback (most recent call last):
self._shutdown_workers()Exception ignored in:   File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Traceback (most recent call last):
    <function _MultiProcessingDataLoaderIter.

Epoch [4/20] Summary: Train Loss=0.7962  |  Val Loss=0.8120  |  Val Acc=0.7517

Current LR: 1.94e-06


Epoch 5/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 5/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [5/20] Summary: Train Loss=0.7957  |  Val Loss=0.8115  |  Val Acc=0.7550

Current LR: 1.87e-06


Epoch 6/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060>
Traceback (most recent call last):
  File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib64/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f51f754b060>
Traceback (most recent call last):
  File "/home/artursultanov/uni/cosmoformer/venv/lib64/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_

Epoch 6/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [6/20] Summary: Train Loss=0.7949  |  Val Loss=0.8113  |  Val Acc=0.7571

Current LR: 1.77e-06


Epoch 7/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 7/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [7/20] Summary: Train Loss=0.7959  |  Val Loss=0.8122  |  Val Acc=0.7514

Current LR: 1.64e-06


Epoch 8/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 8/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [8/20] Summary: Train Loss=0.7951  |  Val Loss=0.8115  |  Val Acc=0.7556

Current LR: 1.50e-06


Epoch 9/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 9/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [9/20] Summary: Train Loss=0.7943  |  Val Loss=0.8149  |  Val Acc=0.7542

Current LR: 1.34e-06


Epoch 10/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 10/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [10/20] Summary: Train Loss=0.7914  |  Val Loss=0.8103  |  Val Acc=0.7548

Current LR: 1.17e-06


Epoch 11/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 11/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [11/20] Summary: Train Loss=0.7929  |  Val Loss=0.8115  |  Val Acc=0.7545

Current LR: 1.00e-06


Epoch 12/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 12/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [12/20] Summary: Train Loss=0.7922  |  Val Loss=0.8115  |  Val Acc=0.7549

Current LR: 8.26e-07


Epoch 13/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 13/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [13/20] Summary: Train Loss=0.7925  |  Val Loss=0.8103  |  Val Acc=0.7554

Current LR: 6.58e-07


Epoch 14/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 14/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [14/20] Summary: Train Loss=0.7961  |  Val Loss=0.8114  |  Val Acc=0.7537

Current LR: 5.00e-07


Epoch 15/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 15/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [15/20] Summary: Train Loss=0.7923  |  Val Loss=0.8106  |  Val Acc=0.7557

Current LR: 3.57e-07


Epoch 16/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 16/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [16/20] Summary: Train Loss=0.7947  |  Val Loss=0.8111  |  Val Acc=0.7551

Current LR: 2.34e-07


Epoch 17/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 17/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [17/20] Summary: Train Loss=0.7939  |  Val Loss=0.8114  |  Val Acc=0.7545

Current LR: 1.34e-07


Epoch 18/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 18/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [18/20] Summary: Train Loss=0.7931  |  Val Loss=0.8111  |  Val Acc=0.7548

Current LR: 6.03e-08


Epoch 19/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 19/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [19/20] Summary: Train Loss=0.7923  |  Val Loss=0.8110  |  Val Acc=0.7549

Current LR: 1.52e-08


Epoch 20/20 [Train]:   0%|          | 0/1963 [00:00<?, ?it/s]

Epoch 20/20 [Val]:   0%|          | 0/655 [00:00<?, ?it/s]

Epoch [20/20] Summary: Train Loss=0.7932  |  Val Loss=0.8110  |  Val Acc=0.7548


## Step 8: Test Set Evaluation

In [15]:
model.eval()

test_loss = 0.0
correct = 0
total = 0

test_tqdm = tqdm(test_loader, desc="Testing", smoothing=0.9)

with torch.no_grad():
    for images, labels in test_tqdm:
        images, labels = images.to(device), labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        
        _, predicted = torch.max(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
        test_tqdm.set_postfix(loss=loss.item())

test_loss /= len(test_loader)
test_acc = correct / total

print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")

Testing:   0%|          | 0/655 [00:00<?, ?it/s]

Test Loss: 0.7960 | Test Acc: 0.7570


## Step 9: Single Prediction

In [16]:
# Example label mapping
label_mapping = {
    0: 'barred_spiral',
    1: 'edge_on_disk',
    2: 'featured_without_bar_or_spiral',
    3: 'irregular',
    4: 'smooth_cigar',
    5: 'smooth_inbetween',
    6: 'smooth_round',
    7: 'unbarred_spiral'
}

idx = random.randint(0, len(train_dataset) - 1)

example_image_tensor, example_label_int = train_dataset[idx]

example_image_tensor = example_image_tensor.unsqueeze(0).to(device)

# Run inference
model.eval()
with torch.no_grad():
    output = model(example_image_tensor)        # shape [1, num_classes]
    _, pred_idx = torch.max(output, dim=1)      # Get predicted class index
    pred_idx = pred_idx.item()                  # convert to int

real_truth_class_name = label_mapping[example_label_int.item()]
predicted_class_name = label_mapping[pred_idx]

print("Random sample index:", idx)
print("Predicted class index:", pred_idx)
print("Predicted class:", predicted_class_name)
print("Real truth class index:", example_label_int.item())
print("Real truth class:", real_truth_class_name)

Random sample index: 45289
Predicted class index: 5
Predicted class: smooth_inbetween
Real truth class index: 5
Real truth class: smooth_inbetween
