In [1]:
import numpy as np
import seaborn as sea
import matplotlib.pyplot as plt
import pandas as pd
import random 
import os
from sklearn.model_selection import train_test_split
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torchvision.transforms as transforms

import warnings
warnings.filterwarnings("ignore")

In [2]:
ROOT = "/kaggle/input/finalized-astrovision-data"

SEED = 3126

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EPOCHS = 100

LEARNING_RATE = .001

BATCH_SIZE = 32

TARGET_MAP = {
    "Ellipticals":0,
    "Irregulars":1,
    "Lenticulars":2,
    "Spirals":3
}

In [3]:
def set_seed(seed):
    random.seed(seed)               # Python's built-in random module
    np.random.seed(seed)            # NumPy
    torch.manual_seed(seed)         # PyTorch CPU
    torch.cuda.manual_seed(seed)    # PyTorch GPU
    torch.cuda.manual_seed_all(seed) # For all GPUs (if applicable)

    # Ensures deterministic behavior in some operations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker_init_fn(worker_id):
    np.random.seed(SEED + worker_id)
    random.seed(SEED + worker_id)
    
set_seed(SEED)

In [4]:
means = [0.1431, 0.1285, 0.1565]
stds = [0.1314, 0.1237, 0.1443]

train_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.RandomHorizontalFlip(.5),
    transforms.RandomVerticalFlip(.5),
    transforms.RandomRotation(180),
    transforms.Normalize(mean=means, std=stds)
])

test_transforms = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds)
])

In [5]:
class GalaxyData(Dataset):
    def __init__(self, ROOT, transforms=None):
        self.ROOT = ROOT
        self.transforms = transforms
        self.imgs = []
        self.labels = []

        for category in TARGET_MAP.keys():
            category_path = os.path.join(ROOT, category)
            img_files = os.listdir(category_path)
            for img_file in img_files:
                img_path = os.path.join(category_path, img_file)
                self.imgs.append(img_path)
                self.labels.append(TARGET_MAP[category])
                    
    def __len__(self):
        return len(self.imgs)
        
    def __getitem__(self, idx):
        img   = Image.open(self.imgs[idx]).convert("RGB")
        label = self.labels[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

In [6]:
dataset = GalaxyData(ROOT, None)

test_fold, train_fold = train_test_split(
    range(len(dataset)),
    test_size=.8,
    shuffle=True,
    random_state=SEED
)

## Separates train data
train_dataset = GalaxyData(ROOT, train_transforms)
train_dataset = Subset(train_dataset, train_fold)
## Separates validation data
val_dataset   = GalaxyData(ROOT, test_transforms)
val_dataset   = Subset(val_dataset, test_fold)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, 
                          num_workers=4, worker_init_fn=worker_init_fn)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, 
                          num_workers=4, worker_init_fn=worker_init_fn)

In [7]:
class EarlyStopping():
    def __init__(self, early_stopping_rounds=5, min_loss_change=0.001, verbose=False):
        """
        An implementation of early stopping to restrain overfitting in the model. 
        
            early_stopping_rounds (int): Number of epochs to wait after last improvement.
            min_score_change (float): Minimum change in monitored metric to qualify as improvement.
            verbose (bool): Whether to print early stopping messages.
        """
        self.early_stopping_rounds = early_stopping_rounds
        self.min_loss_change = min_loss_change
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False

    def __call__(self, model, val_loss):
        ## Improvement in loss has occurred; resets counter and saves current model
        if val_loss < self.best_loss - self.min_loss_change:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), 'best_model.pth')  
        ## No improvement in loss occurred; increases counter
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: [{self.counter}/{self.early_stopping_rounds}]\n")
            if self.counter >= self.early_stopping_rounds:
                self.early_stop = True

In [8]:
def validate(model, loss_fn, val_loader):
    model.eval()
    total_loss, correct, total_samples = 0.0, 0.0, 0.0

    with torch.inference_mode():
        for imgs,labels in val_loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
    
            y_preds = model(imgs)
            loss = loss_fn(y_preds, labels.long())
            total_loss += loss.item()
            
            total_samples += labels.size(0)
            preds = torch.argmax(y_preds, dim=1)
            correct += torch.sum(preds==labels).item()

        accuracy = correct/total_samples
        avg_loss = total_loss/len(val_loader)
    return accuracy, avg_loss

In [9]:
def train(model, optimizer, loss_fn, train_loader):
    model.train()
    total_loss, correct, total_samples = 0.0, 0.0, 0.0
        
    for imgs,labels in train_loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        ## Predicts logit
        y_preds = model(imgs)
        ## Computes loss
        loss = loss_fn(y_preds, labels.long())
        total_loss += loss.item()
        ## Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_samples += labels.size(0)
        preds = torch.argmax(y_preds, dim=1)
        correct += torch.sum(preds==labels).item()
        
    accuracy = correct/total_samples
    avg_loss = total_loss/len(train_loader)
    return accuracy, avg_loss

In [10]:
def train_validate_loop(epochs, model, optimizer, loss_fn, 
                        train_loader, val_loader, 
                        early_stopping_rounds=5, lr_schedule=None):
    train_accuracies = []
    train_losses = []
    val_accuracies = []
    val_losses = []

    history = {}
    early_stopping = EarlyStopping(early_stopping_rounds=early_stopping_rounds,
                                   min_loss_change=0.001,
                                   verbose=True)
    
    for epoch in range(1, epochs+1):
        ### TRAINING PHASE
        train_accuracy, train_loss = train(model, optimizer, loss_fn, train_loader)
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        print(f"Training Epoch [{epoch}/{epochs}] - Accuracy: {train_accuracy:.4f} | Loss: {train_loss:.4f}")

        ### VALIDATION PHASE
        val_accuracy, val_loss = validate(model, loss_fn, val_loader)
        val_accuracies.append(val_accuracy)
        val_losses.append(val_loss)
        print(f"Validation - Accuracy: {val_accuracy:.4f} | Loss: {val_loss:.4f}\n")

        ### EARLY STOPPING CHECK
        early_stopping(model, val_loss)
        if early_stopping.early_stop:
            print(f"Early stopping triggered at epoch {epoch}.")
            break

        ### LR SCHEDULER STEPPING
        if lr_schedule is not None: lr_schedule.step()
            
    history["train_accuracies"] = train_accuracies
    history["val_accuracies"]   = val_accuracies
    history["train_losses"] = train_losses
    history["val_losses"]   = val_losses
    return history

In [11]:
class CoordinateAttention(nn.Module):
    def __init__(self, in_channels, reduction=16):
        super(CoordinateAttention, self).__init__()
        self.in_channels = in_channels
        self.reduction   = reduction
        
        # avg_pool_x: (B, C, H, 1)
        # avg_pool_y: (B, C, 1, W)
        self.avg_pool_x = nn.AdaptiveAvgPool2d((None, 1))
        self.avg_pool_y = nn.AdaptiveAvgPool2d((1, None))
        # Downsampling
        self.conv1  = nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1, bias=False)
        self.bn     = nn.BatchNorm2d(in_channels // reduction)
        # Upsampling to size of input
        self.conv_x = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=True)
        self.conv_y = nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1, bias=True)
        self.init_modules_weights()
        
    def init_modules_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, .25) 
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.constant_(module.bias, 0.25)
                
    def forward(self, x):
        B, C, H, W = x.shape
        x_avg = self.avg_pool_x(x)  # (B, C, H, 1)
        y_avg = self.avg_pool_y(x).permute(0, 1, 3, 2)  # (B, C, 1, W) -> (B, C, W, 1)
        
        concat    = torch.cat([x_avg, y_avg], dim=2)
        attention = F.relu(self.bn(self.conv1(concat)))
        # Split into spatial attentions
        attention_x, attention_y = torch.split(attention, [H, W], dim=2)
        attention_x = torch.sigmoid(self.conv_x(attention_x))
        attention_y = torch.sigmoid(self.conv_y(attention_y))
        # Reshape attention_y to match the original input dimensions
        attention_y = attention_y.permute(0, 1, 3, 2)  # (B, C, W, 1) -> (B, C, 1, W)
        
        return x * attention_x * attention_y

In [12]:
class MobileNetResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, 
                 expansion_factor=3,
                 downsample=False, 
                 use_attention=True, 
                 reduction=16
                ):
        super(MobileNetResBlock, self).__init__()
        stride = 2 if downsample else 1
        expanded_channels = out_channels * expansion_factor
        
        ## Implements attention mechanism
        self.attention1 = nn.Sequential()
        self.attention2 = nn.Sequential()
        self.attention3 = nn.Sequential()
        if use_attention:
            self.attention1 = CoordinateAttention(expanded_channels, reduction)
            self.attention2 = CoordinateAttention(expanded_channels, reduction)
            self.attention3 = CoordinateAttention(out_channels, reduction)
        ## Inverted Bottleneck Structure
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(expanded_channels),
            nn.ReLU(inplace=True),
            self.attention1,
            nn.Conv2d(expanded_channels, expanded_channels, kernel_size=3, padding=1, bias=False, groups=expanded_channels),
            nn.BatchNorm2d(expanded_channels),
            self.attention2,
            nn.Conv2d(expanded_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        ## Projection process matches dimensions of identity to new features
        self.projection = nn.Sequential()
        if (downsample) or (in_channels!=out_channels):
            self.projection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )    
        self.init_modules_weights()
        
    def init_modules_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.25) 
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.constant_(module.bias, 0.25)
                
    def forward(self, x):
        identity = self.projection(x)
        x = self.features(x)
        return F.relu(self.attention3(x + identity))

In [13]:
class CustomMobileNetV3(nn.Module):
    def __init__(self,
                 expansion_factor=3,
                 res_out_channels=[112, 160, 256, 480], 
                 blocks_per_layer=[2, 3, 3, 2],
                 res_attention=[True, True, True, True],
                 reduction=16
                ):
        super(CustomMobileNetV3, self).__init__()
        self.res_in_channels = 80
        
        self.init_conv = nn.Sequential(
            # 3x224x224 -> 64x112x112
            MobileNetResBlock(3, 64, expansion_factor, downsample=True, use_attention=True, reduction=reduction),
            # 64x112x112 -> 64x112x112
            MobileNetResBlock(64, 64, expansion_factor, downsample=False, use_attention=False),
            # 64x112x112 -> 80x56x56
            MobileNetResBlock(64, 80, expansion_factor, downsample=True, use_attention=True, reduction=reduction),
            # 80x56x56 -> 80x56x56
            MobileNetResBlock(80, 80, expansion_factor, downsample=False, use_attention=False),
        )

         # 80x56x56 -> C1x56x56
        self.res_layer1 = self.make_res_layer(res_out_channels[0], blocks_per_layer[0], expansion_factor, False, res_attention[0], reduction)
        # C1x56x56 -> C2x28x28
        self.res_layer2 = self.make_res_layer(res_out_channels[1], blocks_per_layer[1], expansion_factor, True, res_attention[1], reduction)
        # C2x28x28 -> C3x14x14
        self.res_layer3 = self.make_res_layer(res_out_channels[2], blocks_per_layer[2], expansion_factor, True, res_attention[2], reduction)
        # C3x14x14 -> C4x7x7
        self.res_layer4 = self.make_res_layer(res_out_channels[3], blocks_per_layer[3], expansion_factor, True, res_attention[3], reduction)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        self.classifier = nn.Sequential(
            nn.Linear(res_out_channels[-1], 128, bias=False),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 32, bias=False),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 4, bias=True)
        )
        self.init_modules_weights()

    def make_res_layer(self, out_channels, blocks, expansion_factor, downsample=False, use_attention=True, reduction=16):
        ## Initializes first block in layer
        layers = [MobileNetResBlock(self.res_in_channels, out_channels, expansion_factor, downsample, use_attention, reduction)]
        self.res_in_channels = out_channels
        ## Attaches remaining blocks to layer with NO DOWNSAMPLING
        for _ in range(blocks-1):
            layers.append(MobileNetResBlock(out_channels, out_channels, expansion_factor, use_attention=use_attention, reduction=reduction))
        return nn.Sequential(*layers)

    def init_modules_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.25) 
            elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.constant_(module.bias, 0.25)
                
    def forward(self, x):
        x = self.init_conv(x)
        x = self.res_layer1(x)
        x = self.res_layer2(x)
        x = self.res_layer3(x)
        x = self.res_layer4(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [14]:
model = CustomMobileNetV3().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

history = train_validate_loop(
    EPOCHS, 
    model,
    optimizer, 
    loss_fn, 
    train_loader, 
    val_loader,
    early_stopping_rounds=10
)

Training Epoch [1/100] - Accuracy: 0.5713 | Loss: 0.9823
Validation - Accuracy: 0.6384 | Loss: 0.8732

Training Epoch [2/100] - Accuracy: 0.6403 | Loss: 0.8284
Validation - Accuracy: 0.6865 | Loss: 0.8132

Training Epoch [3/100] - Accuracy: 0.6754 | Loss: 0.7512
Validation - Accuracy: 0.7256 | Loss: 0.6569

Training Epoch [4/100] - Accuracy: 0.7069 | Loss: 0.6709
Validation - Accuracy: 0.7581 | Loss: 0.6521

Training Epoch [5/100] - Accuracy: 0.7315 | Loss: 0.6327
Validation - Accuracy: 0.7117 | Loss: 0.6639

EarlyStopping counter: [1/10]

Training Epoch [6/100] - Accuracy: 0.7340 | Loss: 0.6041
Validation - Accuracy: 0.7671 | Loss: 0.5812

Training Epoch [7/100] - Accuracy: 0.7590 | Loss: 0.5571
Validation - Accuracy: 0.7818 | Loss: 0.5119

Training Epoch [8/100] - Accuracy: 0.7659 | Loss: 0.5455
Validation - Accuracy: 0.7679 | Loss: 0.5288

EarlyStopping counter: [1/10]

Training Epoch [9/100] - Accuracy: 0.7737 | Loss: 0.5278
Validation - Accuracy: 0.7011 | Loss: 0.6961

EarlyStoppi