In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch import nn, optim
import os

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Dataset layout: <data_dir>/train and <data_dir>/val
data_dir = './Data/fits_filtered3/augmented/data/'  # point this at your dataset directory
train_dir, val_dir = (os.path.join(data_dir, split) for split in ('train', 'val'))


# Single image for the (commented-out) example prediction near the end of the script.
# Default: a streak example.
image_path = os.path.join(data_dir, 'val/streak/tic38.png')
# Alternatives - uncomment one of these to override the path above:
# image_path = './Data/fits_filtered4/augmented/data/val/streak/tic6_2.png'  # streak
# image_path = os.path.join(data_dir, 'val/no_streak/tic5.png')              # no streak
# image_path = './Data/fits_filtered4/data/val/no_streak/tic5.png'           # no streak

# Folders whose images are classified in bulk; change these to test other folders.
test_folder_path_streak = os.path.join(data_dir, 'val/streak')
test_folder_path_nostreak = os.path.join(data_dir, 'val/no_streak')

# Preprocessing applied to every image (training, validation and inference):
# resize to the 224x224 input ResNet expects, convert to a tensor, then
# normalize with the ImageNet channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# One dataset per split, both using the same preprocessing.
# NOTE(review): ImageFolder derives the integer labels from the sub-folder
# names - inspect `train_dataset.class_to_idx` to confirm which index is "streak".
train_dataset = ImageFolder(train_dir, transform=transform)
val_dataset = ImageFolder(val_dir, transform=transform)

# Batches of 32; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# Start from an ImageNet-pretrained ResNet-50.
# The `weights=` argument replaces the older `pretrained=True` flag, which is
# probably deprecated in newer torchvision releases (error/warning received on 06 XII 2024):
# model = torchvision.models.resnet50(pretrained=True)
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)


# Replace the final fully connected layer with a 2-output head
# (2 classes: streak, no_streak), then move the network to the selected device.
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=2)
model = model.to(device)

# Cross-entropy loss on the two logits; Adam optimizes all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training Loop
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_batch_loss, accuracy_percent)``.

    When ``optimizer`` is given the model is put in train mode and updated after
    every batch; otherwise the model is put in eval mode and gradients are disabled.
    Returns ``(nan, nan)`` when the loader yields no samples, instead of dividing
    by zero.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum, n_correct, n_seen, n_batches = 0.0, 0, 0, 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item()
            _, predicted = outputs.max(1)
            n_seen += labels.size(0)
            n_correct += predicted.eq(labels).sum().item()
            # Count batches here rather than calling len(loader), so any
            # iterable of (images, labels) batches is accepted.
            n_batches += 1

    if n_batches == 0 or n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_batches, 100. * n_correct / n_seen


def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=4, device=None):
    """Train ``model`` and report loss/accuracy on both splits after every epoch.

    Args:
        model: PyTorch module to optimize (updated in place).
        train_loader: Iterable of ``(images, labels)`` batches used for training.
        val_loader: Iterable of ``(images, labels)`` batches used for validation.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        optimizer: Optimizer holding the parameters of ``model``.
        epochs (int): Number of full passes over ``train_loader``.
        device: Device the batches are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so the
            function does not depend on a module-level ``device`` variable.
    Returns:
        The same ``model`` object, left in eval mode after the last validation pass.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(epochs):
        # Training Phase
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')

        # Validation Phase (no optimizer -> eval mode, no gradient tracking)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)
        print(f'Validation Loss: {val_loss:.4f}, Validation Acc: {val_acc:.2f}%')

    return model

# Fine-tune for 20 epochs using the loaders, loss and optimizer defined above.
trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=20,
)

# Persist only the learned weights (the state_dict), not the whole module object.
torch.save(trained_model.state_dict(), 'streak_detector.pth')

# Inference Function
def predict(image_path, model, transform, class_names=("no_streak", "streak")):
    """Classify a single image file and return its class name.

    Args:
        image_path (str): Path to the image to classify.
        model: Trained PyTorch model producing one logit per class.
        transform: Image transformation applied before the forward pass.
        class_names: Class name for each output index. ImageFolder assigns
            indices in sorted folder-name order, so for folders ``no_streak``
            and ``streak`` index 0 is "no_streak" and index 1 is "streak"
            (check ``train_dataset.class_to_idx``). The previous hard-coded
            mapping treated index 0 as "streak", which inverted the result.
    Returns:
        str: The entry of ``class_names`` with the highest logit.
    """
    from PIL import Image
    model.eval()
    # The context manager closes the underlying file once the pixels are converted.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # add the batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    _, prediction = output.max(1)
    return class_names[prediction.item()]

def predict_with_probability(image_path, model, transform, class_names=("no_streak", "streak")):
    """Classify a single image file and return the label with both class probabilities.

    Args:
        image_path (str): Path to the image to classify.
        model: Trained PyTorch model producing one logit per class.
        transform: Image transformation applied before the forward pass.
        class_names: Class name for each output index; must contain "streak"
            and "no_streak". ImageFolder assigns indices in sorted folder-name
            order, so for folders ``no_streak`` and ``streak`` index 0 is
            "no_streak" and index 1 is "streak" (check
            ``train_dataset.class_to_idx``). The previous code unpacked the two
            probabilities in the opposite order, which swapped them.
    Returns:
        tuple: ``(prediction, prob_streak, prob_no_streak)`` where ``prediction``
        is "streak" or "no_streak".
    Raises:
        KeyError: If ``class_names`` lacks "streak" or "no_streak".
    """
    from PIL import Image
    model.eval()
    # The context manager closes the underlying file once the pixels are converted.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    batch = transform(image).unsqueeze(0).to(device)  # Transform and add batch dimension
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)  # Get logits
        probabilities = F.softmax(output, dim=1)  # Convert logits to probabilities
    # Look probabilities up by class name rather than by position.
    by_class = dict(zip(class_names, probabilities[0].tolist()))
    prob_streak, prob_no_streak = by_class["streak"], by_class["no_streak"]
    prediction = "streak" if prob_streak > prob_no_streak else "no_streak"
    return prediction, prob_streak, prob_no_streak

# Test all .png files in a chosen folder
def test_folder(folder_path, model, transform):
    """
    Test all .png files in the specified folder.

    Files are processed in sorted name order so repeated runs print and return
    results in the same order (os.listdir order is arbitrary). The extension is
    matched case-insensitively, and directory entries are skipped.

    Args:
        folder_path (str): Path to the folder containing images.
        model: Trained PyTorch model.
        transform: Image transformation function.
    Returns:
        List of dicts, one per image, with keys "image", "prediction",
        "prob_streak" and "prob_no_streak". Empty if the folder has no .png files.
    """
    results = []  # Store predictions and probabilities
    png_files = sorted(
        name
        for name in os.listdir(folder_path)
        if name.lower().endswith('.png') and os.path.isfile(os.path.join(folder_path, name))
    )

    for png_file in png_files:
        image_path = os.path.join(folder_path, png_file)
        prediction, prob_streak, prob_no_streak = predict_with_probability(image_path, model, transform)
        results.append({
            "image": png_file,
            "prediction": prediction,
            "prob_streak": prob_streak,
            "prob_no_streak": prob_no_streak
        })
        print(f"Image: {png_file}, Prediction: {prediction}, Prob Streak: {prob_streak:.2f}, Prob No Streak: {prob_no_streak:.2f}")

    return results

# Keep each folder's results under its own name: assigning both calls to a single
# variable discarded the streak-folder results as soon as the second call returned.
results_streak = test_folder(test_folder_path_streak, trained_model, transform)
results_nostreak = test_folder(test_folder_path_nostreak, trained_model, transform)
# `results` stays bound to the no-streak results, exactly as before.
results = results_nostreak


# prediction, prob_streak, prob_no_streak = predict_with_probability(image_path, trained_model, transform)

# print(f"Prediction: {prediction}")
# print(f"Probability of Streak: {prob_streak:.2f}")
# print(f"Probability of No Streak: {prob_no_streak:.2f}")


# print(f"Prediction: {predict(image_path, trained_model, transform)}")


Epoch 1/20, Train Loss: 0.9447, Train Acc: 66.67%
Validation Loss: 7.8869, Validation Acc: 73.53%
Epoch 2/20, Train Loss: 0.6431, Train Acc: 78.52%
Validation Loss: 8.3052, Validation Acc: 73.53%
Epoch 3/20, Train Loss: 0.3932, Train Acc: 84.44%
Validation Loss: 5.1041, Validation Acc: 79.41%
Epoch 4/20, Train Loss: 0.2226, Train Acc: 91.85%
Validation Loss: 1.8302, Validation Acc: 88.24%
Epoch 5/20, Train Loss: 0.3758, Train Acc: 91.85%
Validation Loss: 0.8136, Validation Acc: 88.24%
Epoch 6/20, Train Loss: 0.1460, Train Acc: 94.07%
Validation Loss: 0.1280, Validation Acc: 91.18%
Epoch 7/20, Train Loss: 0.1841, Train Acc: 94.81%
Validation Loss: 4.0836, Validation Acc: 41.18%
Epoch 8/20, Train Loss: 0.2173, Train Acc: 95.56%
Validation Loss: 2.1750, Validation Acc: 41.18%
Epoch 9/20, Train Loss: 0.1319, Train Acc: 96.30%
Validation Loss: 0.2596, Validation Acc: 91.18%
Epoch 10/20, Train Loss: 0.0827, Train Acc: 97.04%
Validation Loss: 0.6568, Validation Acc: 79.41%
Epoch 11/20, Train 