In [None]:
import numpy as np
import pandas as pd
import seaborn as sea
import matplotlib.pyplot as plt
import os
import random
from PIL import Image

from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Root folder that holds one sub-directory of images per galaxy class.
# NOTE(review): this is a machine-specific absolute path; change it before
# running the notebook on another machine.
ROOT = r"C:\Users\kdfer\Desktop\ACM Research\GalaxyImages"
SEED = 3126
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 100
LEARNING_RATE = .001
BATCH_SIZE = 64

# Seed every global RNG the notebook uses so that weight initialisation,
# augmentation draws and DataLoader shuffling are repeatable across runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Class-folder name -> integer label used as the classification target.
TARGET_MAP = {
    "Ellipticals":0,
    "Irregulars":1,
    "Lenticulars":2,
    "Spirals":3
}

# Sanity check that the data directory is reachable. The path is derived from
# ROOT so the check can never point at a different folder than the dataset.
path = os.path.join(ROOT, "Ellipticals")
print(os.path.exists(path))

In [None]:
# Per-channel statistics handed to Normalize in both pipelines.
# NOTE(review): presumably measured on this galaxy image set — confirm.
means = [0.1462, 0.1318, 0.1607]
stds  = [0.1302, 0.1230, 0.1437]

_INPUT_SIZE = (224, 224)
_BICUBIC = transforms.InterpolationMode.BICUBIC

# Training pipeline: resize, random geometric/photometric augmentation, then
# tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    # degrees=0 -> no rotation here; only a shift of up to 32/224 of the image
    # size along each axis and a 0.85x-1.15x rescale.
    transforms.RandomAffine(
        degrees=0,
        translate=(32 / 224, 32 / 224),
        scale=(0.85, 1.15),
        interpolation=_BICUBIC,
    ),
    transforms.ColorJitter(brightness=(0.85, 1.35)),
    # A scalar `degrees` means the angle is drawn from (-350, +350).
    # NOTE(review): 350 rather than 180/360 looks unusual — confirm it is intended.
    transforms.RandomRotation(degrees=350, interpolation=_BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])

# Evaluation pipeline: the deterministic steps only (no augmentation).
test_transforms = transforms.Compose([
    transforms.Resize(_INPUT_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=means, std=stds),
])

In [None]:
# this gets all the data
import os
from torch.utils.data import Dataset
from PIL import Image

class GalaxiesData(Dataset):
    """Image dataset laid out as one sub-directory of ``ROOT`` per class.

    The class folders and their integer labels come from the module-level
    ``TARGET_MAP``. Only file paths are collected at construction time; images
    are opened lazily in ``__getitem__``.

    Args:
        ROOT: Directory containing the class sub-directories.
        transforms: Optional callable applied to each RGB PIL image before it
            is returned.

    Attributes:
        imgs: Image file paths, grouped by class in ``TARGET_MAP`` order and
            sorted by file name inside each class.
        labels: Integer label for the path at the same position in ``imgs``.
    """

    def __init__(self, ROOT, transforms=None):
        self.ROOT = os.path.normpath(ROOT)
        self.transforms = transforms
        self.imgs = []
        self.labels = []

        for category, label in TARGET_MAP.items():
            category_path = os.path.normpath(os.path.join(self.ROOT, category))

            # A missing class folder is reported but does not abort indexing.
            if not os.path.exists(category_path):
                print(f"Warning: {category_path} does not exist.")
                continue

            # os.listdir returns entries in arbitrary order. Sorting makes the
            # index -> image mapping reproducible, which matters because the
            # notebook splits several instances of this dataset by index.
            for img_name in sorted(os.listdir(category_path)):
                img_path = os.path.normpath(os.path.join(category_path, img_name))

                # Skip nested directories: they cannot be opened as images.
                if not os.path.isfile(img_path):
                    continue

                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return ``(image, label)``.

        ``image`` is the output of ``self.transforms`` when one was given,
        otherwise the RGB PIL image itself.
        """
        img_path = self.imgs[idx]

        # convert() yields a new, fully loaded image, so the underlying file
        # can be closed as soon as the conversion is done.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        label = self.labels[idx]
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label



In [None]:
# define all the datasets and training/test stuff

def worker_init_fn(worker_id):
    """Seed NumPy's global RNG from the calling process's torch seed.

    Passed to ``DataLoader`` as ``worker_init_fn``. ``worker_id`` is accepted
    to satisfy that callback signature but is not used.
    """
    # np.random.seed only accepts 32-bit values, so fold the torch seed down.
    numpy_seed = torch.initial_seed() % (2 ** 32)
    np.random.seed(numpy_seed)

# Three views of the same image folder: only the training view is augmented.
train_dataset = GalaxiesData(ROOT, transforms=train_transforms)
val_dataset   = GalaxiesData(ROOT, transforms=test_transforms)
test_dataset  = GalaxiesData(ROOT, transforms=test_transforms)

# The split below is expressed as indices and applied to all three instances,
# so they must list exactly the same files in the same order; otherwise the
# train/val/test subsets could overlap.
if not (train_dataset.imgs == val_dataset.imgs == test_dataset.imgs):
    raise RuntimeError("Dataset instances index different files; index-based split would be invalid.")

# Define data sizes
n_images   = len(train_dataset)
train_size = int(0.7 * n_images)                 # 70% for training
val_size   = int(0.15 * n_images)                # 15% for validation
test_size  = n_images - train_size - val_size    # remainder (~15%) for testing

# Extract indices for training and validation+testing
train_indices, val_test_indices = train_test_split(
    range(n_images),
    test_size=(val_size + test_size),
    shuffle=True,
    random_state=SEED
)

# Extract validation and test indices.
# The absolute count is passed instead of a ratio: a float ratio is turned
# back into a count with ceil/floor, and floating-point error can then move
# or drop one sample.
val_indices, test_indices = train_test_split(
    val_test_indices,
    test_size=test_size,
    shuffle=True,
    random_state=SEED
)

# Separate datasets using Subset
train_subset = Subset(train_dataset, train_indices)
val_subset   = Subset(val_dataset, val_indices)
test_subset  = Subset(test_dataset, test_indices)

# Create DataLoader for each set (only the training loader reshuffles).
# NOTE(review): ROOT is a Windows path; worker processes on Windows need to
# re-import the dataset class, which may not work for classes defined inside a
# notebook. If loading hangs or fails, try num_workers=0 — confirm on the
# target machine.
train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, worker_init_fn=worker_init_fn)
val_loader   = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)
test_loader  = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, worker_init_fn=worker_init_fn)

In [None]:
class EarlyStopping():
    def __init__(self, early_stopping_rounds=5, min_loss_change=0.001, verbose=False,
                 checkpoint_path='best_model.pth'):
        """
        An implementation of early stopping to restrain overfitting in the model.

            early_stopping_rounds (int): Number of consecutive calls without
                improvement after which ``early_stop`` is set.
            min_loss_change (float): Amount by which the loss must drop below
                the best loss so far to count as an improvement.
            verbose (bool): Whether to print the counter on non-improving calls.
            checkpoint_path (str): File the model's ``state_dict`` is saved to
                on every improvement. Defaults to 'best_model.pth'.
        """
        self.early_stopping_rounds = early_stopping_rounds
        self.min_loss_change = min_loss_change
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        self.counter = 0                 # calls since the last improvement
        self.best_loss = float('inf')    # lowest loss seen so far
        self.early_stop = False          # set once the patience is exhausted

    def __call__(self, model, val_loss):
        """Record one validation loss and update the stopping state.

        On improvement the counter is reset and ``model.state_dict()`` is
        written to ``checkpoint_path``; otherwise the counter is incremented
        and ``early_stop`` becomes True once it reaches
        ``early_stopping_rounds``.
        """
        ## Improvement in loss has occurred; resets counter and saves current model
        if val_loss < self.best_loss - self.min_loss_change:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.checkpoint_path)
        ## No improvement in loss occurred; increases counter
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: [{self.counter}/{self.early_stopping_rounds}]\n")
            if self.counter >= self.early_stopping_rounds:
                self.early_stop = True

In [None]:
def train(model, optimizer, loss_fn, train_loader):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``DEVICE``, scored with ``loss_fn`` and used for a
    single optimizer step.

    Returns:
        (accuracy, avg_loss): fraction of samples whose arg-max prediction
        equals the label, and the mean of the per-batch loss values.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0.0
    n_seen = 0.0

    for batch_imgs, batch_labels in train_loader:
        batch_imgs = batch_imgs.to(DEVICE)
        batch_labels = batch_labels.to(DEVICE)

        ## Forward pass and loss
        logits = model(batch_imgs)
        batch_loss = loss_fn(logits, batch_labels.long())
        running_loss += batch_loss.item()

        ## Backpropagation
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        ## Running accuracy bookkeeping
        n_seen += batch_labels.size(0)
        n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()

    return n_correct / n_seen, running_loss / len(train_loader)

In [None]:
def validate(model, loss_fn, val_loader):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    The model is switched to eval mode and every batch is processed under
    ``torch.inference_mode()`` after being moved to ``DEVICE``.

    Returns:
        (accuracy, avg_loss): fraction of samples whose arg-max prediction
        equals the label, and the mean of the per-batch loss values.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0.0
    n_seen = 0.0

    with torch.inference_mode():
        for batch_imgs, batch_labels in val_loader:
            batch_imgs = batch_imgs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            logits = model(batch_imgs)
            running_loss += loss_fn(logits, batch_labels.long()).item()

            n_seen += batch_labels.size(0)
            n_correct += (logits.argmax(dim=1) == batch_labels).sum().item()

    return n_correct / n_seen, running_loss / len(val_loader)

In [None]:
def train_validate_loop(epochs, model, optimizer, loss_fn, 
                        train_loader, val_loader, 
                        early_stopping_rounds=5, lr_schedule=None):
    """Alternate training and validation epochs with early stopping.

    Each epoch calls ``train`` then ``validate``, prints both results and
    feeds the validation loss to an ``EarlyStopping`` instance. The loop ends
    after ``epochs`` epochs or as soon as early stopping triggers. When
    ``lr_schedule`` is given, its ``step()`` runs at the end of every epoch
    that did not trigger early stopping.

    Returns:
        dict with the per-epoch lists ``train_accuracies``, ``val_accuracies``,
        ``train_losses`` and ``val_losses``.
    """
    history = {
        "train_accuracies": [],
        "val_accuracies": [],
        "train_losses": [],
        "val_losses": [],
    }
    stopper = EarlyStopping(early_stopping_rounds=early_stopping_rounds,
                            min_loss_change=0.001,
                            verbose=True)

    for epoch in range(1, epochs + 1):
        ### TRAINING PHASE
        train_accuracy, train_loss = train(model, optimizer, loss_fn, train_loader)
        history["train_accuracies"].append(train_accuracy)
        history["train_losses"].append(train_loss)
        print(f"Training Epoch [{epoch}/{epochs}] - Accuracy: {train_accuracy:.4f} | Loss: {train_loss:.4f}")

        ### VALIDATION PHASE
        val_accuracy, val_loss = validate(model, loss_fn, val_loader)
        history["val_accuracies"].append(val_accuracy)
        history["val_losses"].append(val_loss)
        print(f"Validation - Accuracy: {val_accuracy:.4f} | Loss: {val_loss:.4f}\n")

        ### EARLY STOPPING CHECK
        stopper(model, val_loss)
        if stopper.early_stop:
            print(f"Early stopping triggered at epoch {epoch}.")
            break

        ### LR SCHEDULER STEPPING
        if lr_schedule is not None:
            lr_schedule.step()

    return history

In [None]:
# teacher model

class TransformerClassifier(nn.Module):
    """Frozen pretrained Hiera backbone followed by a small trainable MLP head.

    The backbone is fetched through ``torch.hub`` and all of its parameters
    are frozen. Its output is passed through four Linear -> BatchNorm1d -> ReLU
    blocks (1000 -> 512 -> 256 -> 128 -> 8), a dropout layer and a final linear
    classifier.

    Args:
        dropout: Base dropout rate; the head applies ``dropout / 2``.
        num_classes: Number of output logits.
    """

    def __init__(self, dropout=0.2, num_classes=4):
        super(TransformerClassifier, self).__init__()
        self.transformer = torch.hub.load("facebookresearch/hiera", 
                                          model="hiera_base_224", 
                                          pretrained=True, 
                                          checkpoint="mae_in1k_ft_in1k")
        # Width of the backbone output consumed by the head.
        # NOTE(review): presumably the 1000 ImageNet-1k logits of the
        # fine-tuned checkpoint — confirm.
        self.head_features = 1000
        self.num_classes = num_classes
        
        ## Freeze transformer layers
        for param in self.transformer.parameters():
            param.requires_grad = False
            
        ## Function for creating linear blocks
        def linearblock(in_features, out_features):
            # The Linear bias is omitted because BatchNorm1d adds its own shift.
            return nn.Sequential(
                nn.Linear(in_features, out_features, bias=False),
                nn.BatchNorm1d(out_features),
                nn.ReLU(inplace=True)
            )

        ## Intermediate hidden layers
        self.fc_layers = nn.Sequential(
            linearblock(self.head_features, 512),
            #nn.Dropout(dropout),
            linearblock(512, 256),
            #nn.Dropout(dropout),
            linearblock(256, 128),
            #nn.Dropout(dropout),
            linearblock(128, 8),
            nn.Dropout(dropout/2),
        )
                
        ## Final classification layer
        self.classifier = nn.Linear(8, num_classes, bias=True)
        ## Weight initialization
        #self.init_layers_weights_()
        
    def init_layers_weights_(self):
        """Re-initialise the trainable head in place.

        Only ``fc_layers`` and ``classifier`` are visited. Walking
        ``self.modules()`` would also reach the backbone's layers and replace
        its pretrained weights with random values.
        """
        for head_part in (self.fc_layers, self.classifier):
            for module in head_part.modules():
                ## Kaiming Normal initialization for linear layers
                if isinstance(module, nn.Linear):
                    nn.init.kaiming_normal_(module.weight, mode="fan_in", nonlinearity="relu")
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)
                ## Identity initialization for batchnorm layers
                elif isinstance(module, nn.BatchNorm1d):
                    nn.init.ones_(module.weight)
                    nn.init.zeros_(module.bias)
                
    def forward(self, x):
        """Return class logits: backbone -> hidden blocks -> classifier."""
        x = self.transformer(x)
        x = self.fc_layers(x)
        x = self.classifier(x)
        return x

In [None]:
!pip install timm

In [None]:
# training model

# Build the classifier on the target device together with its loss and
# optimizer, then run the train/validate loop and keep the per-epoch metrics.
model = TransformerClassifier().to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

history = train_validate_loop(
    epochs=EPOCHS,
    model=model,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    early_stopping_rounds=10,
)