In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
from collections import Counter

In [29]:
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [30]:
class GalaxyDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.image_files = [f for f in os.listdir(root_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]
        
        # Store all labels for sampling weights calculation
        self.labels = []
        for img_name in self.image_files:
            label_str = img_name.split('_')[-1].split('.')[0]
            self.labels.append(int(label_str))
        
        # Calculate class weights
        label_counter = Counter(self.labels)
        self.class_weights = {cls: 1.0/count for cls, count in label_counter.items()}
        
        # Store sample weights for WeightedRandomSampler
        self.sample_weights = [self.class_weights[label] for label in self.labels]
        self.sample_weights = torch.FloatTensor(self.sample_weights)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        
        label = self.labels[idx]
        
        # Load and preprocess image
        image = Image.open(img_path).convert("RGB")
        image = image.resize((128, 128))
        
        # Data augmentation for training
        if torch.rand(1) > 0.5:  # Random horizontal flip
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        
        # Convert to numpy array and normalize
        image_array = np.array(image) / 255.0
        
        # Convert to tensor
        image_tensor = torch.FloatTensor(image_array).permute(2, 0, 1)
        
        return image_tensor, label

In [31]:
# Dataset path
data_dir = os.path.expanduser("~/Desktop/BE Project/Decals_data_images")

# Create dataset
dataset = GalaxyDataset(root_dir=data_dir)

# Split dataset
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [32]:
class GalaxyCNN(nn.Module):
    def __init__(self, num_classes):
        super(GalaxyCNN, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(512 * 8 * 8, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(4096, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

In [33]:
class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf

    def __call__(self, val_loss):
        score = -val_loss  # for loss, we want to minimize
        
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score - self.delta:  # Score decreased (got worse)
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:  # Score improved
            self.best_score = score
            self.counter = 0

In [34]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs):
    model.train()
    best_val_acc = 0.0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            progress_bar.set_postfix({
                'loss': running_loss/len(train_loader),
                'accuracy': 100.*correct/total
            })
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()
        
        val_loss = val_loss / len(val_loader)
        val_accuracy = 100. * val_correct / val_total
        
        print(f'\nEpoch {epoch+1}/{num_epochs}:')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Validation Accuracy: {val_accuracy:.2f}%')
        
        # Early stopping check
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
        
        # Scheduler step
        scheduler.step(val_accuracy)
        
        # Save best model
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            torch.save(model.state_dict(), 'best_model.pth')

def evaluate_model(model, loader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc='Evaluating'):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    accuracy = 100. * correct / total
    print(f'Accuracy: {accuracy:.2f}%')
    return accuracy

In [35]:
# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create data loaders with weighted sampling for training
train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    sampler=WeightedRandomSampler(
        weights=dataset.sample_weights[:train_size],
        num_samples=train_size,
        replacement=True
    )
)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Initialize model and move to device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GalaxyCNN(num_classes=10).to(device)

# Setup class weights for loss function
class_weights = torch.FloatTensor([
    1.0/1081,  # Class 0
    1.0/1853,  # Class 1
    1.0/2645,  # Class 2
    1.0/2027,  # Class 3
    1.0/334,   # Class 4
    1.0/2043,  # Class 5
    1.0/1829,  # Class 6
    1.0/2628,  # Class 7
    1.0/1423,  # Class 8
    1.0/1873   # Class 9
]).to(device)

# Initialize criterion, optimizer, scheduler, and early stopping
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
early_stopping = EarlyStopping(
    patience=5,  # Reduced patience
    verbose=True,
    delta=0.01  # Minimum change to qualify as an improvement
)
# Modify scheduler parameters
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='max',
    factor=0.1, 
    patience=3,  # Reduced patience
    verbose=True,
    min_lr=1e-6
)

# Train model
num_epochs = 30
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, early_stopping, num_epochs)

# Evaluate the model
evaluate_model(model, test_loader)

# Save the trained model
torch.save(model.state_dict(), "galaxy_cnn_balanced.pth")
print("Model saved successfully!")

Epoch 1/30: 100%|██████████| 444/444 [01:23<00:00,  5.31it/s, loss=2.49, accuracy=21.4] 



Epoch 1/30:
Validation Loss: 1.9806
Validation Accuracy: 30.36%


Epoch 2/30: 100%|██████████| 444/444 [01:33<00:00,  4.75it/s, loss=1.95, accuracy=27.4] 



Epoch 2/30:
Validation Loss: 1.8633
Validation Accuracy: 30.69%


Epoch 3/30: 100%|██████████| 444/444 [01:20<00:00,  5.51it/s, loss=1.8, accuracy=31.8]  



Epoch 3/30:
Validation Loss: 1.6255
Validation Accuracy: 36.25%


Epoch 4/30: 100%|██████████| 444/444 [01:20<00:00,  5.54it/s, loss=1.68, accuracy=36.2] 



Epoch 4/30:
Validation Loss: 1.5622
Validation Accuracy: 42.14%


Epoch 5/30: 100%|██████████| 444/444 [01:20<00:00,  5.54it/s, loss=1.64, accuracy=38]   



Epoch 5/30:
Validation Loss: 1.5864
Validation Accuracy: 39.32%
EarlyStopping counter: 1 out of 5


Epoch 6/30: 100%|██████████| 444/444 [01:20<00:00,  5.54it/s, loss=1.56, accuracy=40.3] 



Epoch 6/30:
Validation Loss: 1.5620
Validation Accuracy: 44.19%


Epoch 7/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=1.49, accuracy=42.9] 



Epoch 7/30:
Validation Loss: 1.4143
Validation Accuracy: 46.11%


Epoch 8/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=1.48, accuracy=43.7] 



Epoch 8/30:
Validation Loss: 1.4078
Validation Accuracy: 49.69%


Epoch 9/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=1.4, accuracy=46.7]  



Epoch 9/30:
Validation Loss: 1.5264
Validation Accuracy: 42.70%
EarlyStopping counter: 1 out of 5


Epoch 10/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=1.38, accuracy=47.5] 



Epoch 10/30:
Validation Loss: 1.3496
Validation Accuracy: 54.14%


Epoch 11/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=1.34, accuracy=49.6] 



Epoch 11/30:
Validation Loss: 1.2709
Validation Accuracy: 50.96%


Epoch 12/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=1.28, accuracy=51.7] 



Epoch 12/30:
Validation Loss: 1.2648
Validation Accuracy: 57.05%


Epoch 13/30: 100%|██████████| 444/444 [01:19<00:00,  5.58it/s, loss=1.24, accuracy=53.6] 



Epoch 13/30:
Validation Loss: 1.3396
Validation Accuracy: 53.64%
EarlyStopping counter: 1 out of 5


Epoch 14/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=1.18, accuracy=55.9] 



Epoch 14/30:
Validation Loss: 1.3484
Validation Accuracy: 47.24%
EarlyStopping counter: 2 out of 5


Epoch 15/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=1.15, accuracy=57.2] 



Epoch 15/30:
Validation Loss: 1.3497
Validation Accuracy: 53.66%
EarlyStopping counter: 3 out of 5


Epoch 16/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=1.12, accuracy=58.9] 



Epoch 16/30:
Validation Loss: 1.1450
Validation Accuracy: 62.49%


Epoch 17/30: 100%|██████████| 444/444 [01:19<00:00,  5.58it/s, loss=1.06, accuracy=60.9] 



Epoch 17/30:
Validation Loss: 1.0811
Validation Accuracy: 65.50%


Epoch 18/30: 100%|██████████| 444/444 [01:19<00:00,  5.58it/s, loss=1.03, accuracy=62.6] 



Epoch 18/30:
Validation Loss: 1.0324
Validation Accuracy: 65.73%


Epoch 19/30: 100%|██████████| 444/444 [01:19<00:00,  5.58it/s, loss=0.979, accuracy=64.6]



Epoch 19/30:
Validation Loss: 0.9793
Validation Accuracy: 66.12%


Epoch 20/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=0.961, accuracy=65]  



Epoch 20/30:
Validation Loss: 1.0227
Validation Accuracy: 64.40%
EarlyStopping counter: 1 out of 5


Epoch 21/30: 100%|██████████| 444/444 [01:20<00:00,  5.50it/s, loss=0.899, accuracy=67.9]



Epoch 21/30:
Validation Loss: 0.9451
Validation Accuracy: 68.15%


Epoch 22/30: 100%|██████████| 444/444 [01:19<00:00,  5.55it/s, loss=0.888, accuracy=67.9]



Epoch 22/30:
Validation Loss: 1.5426
Validation Accuracy: 48.99%
EarlyStopping counter: 1 out of 5


Epoch 23/30: 100%|██████████| 444/444 [01:20<00:00,  5.55it/s, loss=0.861, accuracy=69.4]



Epoch 23/30:
Validation Loss: 0.9562
Validation Accuracy: 69.45%
EarlyStopping counter: 2 out of 5


Epoch 24/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=0.84, accuracy=70.2] 



Epoch 24/30:
Validation Loss: 1.0230
Validation Accuracy: 67.98%
EarlyStopping counter: 3 out of 5


Epoch 25/30: 100%|██████████| 444/444 [01:19<00:00,  5.57it/s, loss=0.803, accuracy=71]  



Epoch 25/30:
Validation Loss: 1.1273
Validation Accuracy: 64.09%
EarlyStopping counter: 4 out of 5


Epoch 26/30: 100%|██████████| 444/444 [01:19<00:00,  5.56it/s, loss=0.777, accuracy=71.9]



Epoch 26/30:
Validation Loss: 1.0327
Validation Accuracy: 64.32%
EarlyStopping counter: 5 out of 5
Early stopping triggered


Evaluating: 100%|██████████| 111/111 [00:13<00:00,  7.94it/s]


Accuracy: 69.81%
Model saved successfully!
