In [8]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from torch import nn, optim
import os

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Paths: dataset root with 'train' and 'val' sub-directories.
data_dir = './Data/fits_filtered3/augmented/data/'
train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'val')

# Example Prediction

# Single-image example from the "streak" class.
# (image_path is not referenced by the rest of the code visible in this cell.)
image_path = os.path.join(data_dir, 'val/streak/tic38.png')
# Alternative single-image example from the "no_streak" class:
# image_path = os.path.join(data_dir, 'val/no_streak/tic5.png')

# Folders evaluated image-by-image after training; point these elsewhere to test other data.
test_folder_path_streak = os.path.join(data_dir, 'val/streak')  # Change this to the folder you want to test
test_folder_path_nostreak = os.path.join(data_dir, 'val/no_streak')  # Change this to the folder you want to test

# Data Transformations: the same pipeline is applied to training, validation
# and single-image inference (no augmentation is performed here).
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the 224x224 input size used with ResNet
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet normalization
])

# Load Datasets: ImageFolder derives one integer label per class sub-folder.
# NOTE(review): per the torchvision docs the class folders are indexed in sorted
# order, so with 'no_streak' and 'streak' folders index 0 would be 'no_streak'
# and index 1 'streak' - confirm via train_dataset.class_to_idx.
train_dataset = ImageFolder(root=train_dir, transform=transform)
val_dataset = ImageFolder(root=val_dir, transform=transform)

# Only the training data is shuffled; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Legacy form, kept for reference only:
#   model = torchvision.models.resnet50(pretrained=True)
# The `pretrained=` argument is probably deprecated (after torchvision 0.13?);
# an error/warning was observed on 06 XII 2024, hence the `weights=` API below.
# Load ResNet50 with weights
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)  # Load pre-trained weights


# Modify Output Layer for Binary Classification: replace the final fully
# connected layer with a fresh 2-output layer, then move the model to `device`.
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)  # 2 classes: streak, no_streak
model = model.to(device)

# Define Loss and Optimizer: all model parameters are passed to Adam,
# so the whole network is fine-tuned, not just the new head.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training Loop
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=4):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Every epoch performs one optimisation pass over ``train_loader`` and then
    one gradient-free pass over ``val_loader``. Batches are moved to the
    module-level ``device`` before being fed to the model. For each phase a
    line is printed with the mean of the per-batch losses and the accuracy
    (percentage of samples whose highest-scoring class equals the label).

    Args:
        model: Network to train; it is updated in place.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        The same ``model`` object that was passed in.
    """

    def count_hits(scores, targets):
        # Number of rows whose highest-scoring class matches the target label.
        top_class = scores.max(1)[1]
        return top_class.eq(targets).sum().item()

    for epoch in range(epochs):
        # ---- training pass: weights are updated after every batch ----
        model.train()
        train_loss = 0
        hits = 0
        seen = 0
        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            scores = model(batch_x)
            batch_loss = criterion(scores, batch_y)
            batch_loss.backward()
            optimizer.step()

            train_loss += batch_loss.item()
            seen += batch_y.size(0)
            hits += count_hits(scores, batch_y)

        train_acc = 100. * hits / seen
        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss/len(train_loader):.4f}, Train Acc: {train_acc:.2f}%')

        # ---- validation pass: eval mode, no gradients, no weight updates ----
        model.eval()
        val_loss = 0
        hits = 0
        seen = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x = batch_x.to(device)
                batch_y = batch_y.to(device)

                scores = model(batch_x)
                val_loss += criterion(scores, batch_y).item()
                seen += batch_y.size(0)
                hits += count_hits(scores, batch_y)

        val_acc = 100. * hits / seen
        print(f'Validation Loss: {val_loss/len(val_loader):.4f}, Validation Acc: {val_acc:.2f}%')

    return model

# Train the Model: 64 epochs here (train_model's own default is 4).
# train_model returns the model object it was given, so `trained_model`
# and `model` refer to the same network.
trained_model = train_model(model, train_loader, val_loader, criterion, optimizer, epochs=64)

# Save the Model: only the parameters (state_dict) are written, to the
# current working directory.
torch.save(trained_model.state_dict(), 'streak_detector.pth')

# Inference Function
def predict(image_path, model, transform, streak_index=1):
    """Classify one image file as ``"streak"`` or ``"no_streak"``.

    The image is loaded as RGB, passed through ``transform``, given a batch
    dimension, moved to the module-level ``device`` and fed to ``model`` in
    eval mode with gradient tracking disabled.

    Args:
        image_path: Path to the image file to classify.
        model: Trained two-class network producing one logit per class.
        transform: Callable turning a PIL image into an input tensor.
        streak_index: Output index that corresponds to the "streak" class.
            The default of 1 matches how ``ImageFolder`` numbers the class
            folders used in this file: folder names are indexed in sorted
            order, so ``no_streak`` is 0 and ``streak`` is 1. The previous
            code treated index 0 as "streak", which inverted every label.

    Returns:
        ``"streak"`` if the highest-scoring output is ``streak_index``,
        otherwise ``"no_streak"``.
    """
    from PIL import Image

    model.eval()
    # Open inside a context manager so the file handle is released promptly;
    # convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image)
    _, prediction = output.max(1)
    return "streak" if prediction.item() == streak_index else "no_streak"

def predict_with_probability(image_path, model, transform, streak_index=1):
    """Classify one image and return the softmax probability of each class.

    The image is loaded as RGB, passed through ``transform``, given a batch
    dimension, moved to the module-level ``device`` and fed to ``model`` in
    eval mode with gradient tracking disabled.

    Args:
        image_path: Path to the image file to classify.
        model: Trained two-class network producing one logit per class.
        transform: Callable turning a PIL image into an input tensor.
        streak_index: Output index (0 or 1) that corresponds to the "streak"
            class; the other index is taken as "no_streak". The default of 1
            matches how ``ImageFolder`` numbers the class folders used in
            this file: folder names are indexed in sorted order, so
            ``no_streak`` is 0 and ``streak`` is 1. The previous code read
            index 0 as the streak probability, which swapped the two values.

    Returns:
        Tuple ``(prediction, prob_streak, prob_no_streak)`` where
        ``prediction`` is ``"streak"`` when ``prob_streak`` is strictly
        greater than ``prob_no_streak`` and ``"no_streak"`` otherwise.
    """
    from PIL import Image

    model.eval()
    # Open inside a context manager so the file handle is released promptly;
    # convert() returns an independent in-memory copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # Transform and add batch dimension
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(image)  # Get logits
    probabilities = F.softmax(output, dim=1)[0].tolist()  # Per-class probabilities of the single image
    prob_streak = probabilities[streak_index]
    prob_no_streak = probabilities[1 - streak_index]  # the other class of the binary head
    prediction = "streak" if prob_streak > prob_no_streak else "no_streak"
    return prediction, prob_streak, prob_no_streak

# Test all .png files in a chosen folder
def test_folder(folder_path, model, transform):
    """
    Run the classifier on every ``.png`` file directly inside a folder.

    Each matching file is passed to ``predict_with_probability``; the outcome
    is printed on its own line and collected into the returned list. Files
    are processed in the order ``os.listdir`` yields them, and only names
    ending in lowercase ``.png`` are considered.

    Args:
        folder_path (str): Directory containing the images to classify.
        model: Trained PyTorch model.
        transform: Image transformation applied before inference.
    Returns:
        List of dicts, one per image, with the keys ``"image"``,
        ``"prediction"``, ``"prob_streak"`` and ``"prob_no_streak"``.
    """
    results = []

    for png_file in os.listdir(folder_path):
        # Skip anything that is not a .png file.
        if not png_file.endswith('.png'):
            continue

        prediction, prob_streak, prob_no_streak = predict_with_probability(
            os.path.join(folder_path, png_file), model, transform
        )
        record = {
            "image": png_file,
            "prediction": prediction,
            "prob_streak": prob_streak,
            "prob_no_streak": prob_no_streak
        }
        results.append(record)
        print(f"Image: {png_file}, Prediction: {prediction}, Prob Streak: {prob_streak:.2f}, Prob No Streak: {prob_no_streak:.2f}")

    return results
# Per-image predictions for the validation images labelled "streak".
print("streak")
results = test_folder(test_folder_path_streak, trained_model, transform)
# Per-image predictions for the validation images labelled "no_streak".
# NOTE: `results` is reassigned here, so the list returned for the streak
# folder above is no longer reachable after this line.
print("no_streaks")
results = test_folder(test_folder_path_nostreak, trained_model, transform)




Epoch 1/64, Train Loss: 0.9143, Train Acc: 62.22%
Validation Loss: 23.5594, Validation Acc: 41.18%
Epoch 2/64, Train Loss: 0.4339, Train Acc: 78.52%
Validation Loss: 5.7869, Validation Acc: 41.18%
Epoch 3/64, Train Loss: 0.3808, Train Acc: 85.19%
Validation Loss: 2.9240, Validation Acc: 41.18%
Epoch 4/64, Train Loss: 0.3026, Train Acc: 92.59%
Validation Loss: 3.3921, Validation Acc: 41.18%
Epoch 5/64, Train Loss: 0.2120, Train Acc: 94.07%
Validation Loss: 2.8435, Validation Acc: 41.18%
Epoch 6/64, Train Loss: 0.1751, Train Acc: 93.33%
Validation Loss: 4.5643, Validation Acc: 41.18%
Epoch 7/64, Train Loss: 0.1727, Train Acc: 95.56%
Validation Loss: 0.7494, Validation Acc: 55.88%
Epoch 8/64, Train Loss: 0.1949, Train Acc: 92.59%
Validation Loss: 0.1134, Validation Acc: 94.12%
Epoch 9/64, Train Loss: 0.1129, Train Acc: 96.30%
Validation Loss: 0.1375, Validation Acc: 85.29%
Epoch 10/64, Train Loss: 0.1601, Train Acc: 94.81%
Validation Loss: 0.1059, Validation Acc: 94.12%
Epoch 11/64, Train