<a href="https://colab.research.google.com/github/aishstronomer/flare-finder/blob/main/flare_classifier.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
# set globals

# do Google Colab things
try:
    from google.colab import drive

    drive.mount("/content/drive", force_remount=True)
    IN_COLAB = True
except:
    IN_COLAB = False

# install dependencies
path_to_coderepo = (
    "/content/drive/MyDrive/ML_project/code_repo/flare-finder" if IN_COLAB else "."
)
if IN_COLAB:
    !pip install -r {path_to_coderepo}/requirements.txt


Mounted at /content/drive


In [5]:
# set globals

# import standard libraries
from sklearn.model_selection import train_test_split
import os
import pandas as pd

In [19]:
# creating model class for big flare prediction
from PIL import Image
import torch
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


class BigFlareFinder:
    def __init__(self):
        self.pytorch_model = None

    def fit(self, image_paths, image_labels, val_frac=0.2):

        # split the data into train-validation using sklearn and use stratified sampling
        image_paths_train, image_paths_val, image_labels_train, image_labels_val = train_test_split(
            image_paths, image_labels, test_size=val_frac, random_state=42, stratify=image_labels)

        # for training, augment minority-class by making copies
        image_label_counts = pd.Series(image_labels).value_counts().sort_values()
        minority_class, majority_class = tuple(image_label_counts.index)
        class_count_diff = image_label_counts[majority_class] - image_label_counts[minority_class]
        image_paths_train_new = pd.Series(image_paths_train[image_labels_train == minority_class]).sample(
            class_count_diff, replace=True, random_state=42).to_list()
        image_labels_train_new = [minority_class for _ in range(class_count_diff)]
        image_paths_train = image_paths_train + image_paths_train_new
        image_labels_train = image_labels_train + image_labels_train_new

        # get dataloader for train and validation data
        train_loader = BigFlareFinder.preprocess(image_paths_train, image_labels_train)
        validation_loader = BigFlareFinder.preprocess(image_paths_val, image_labels_val)

        # fit pytorch model using dataloader

        # load resnet18 model
        model_resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)

        # Freeze all params except the BatchNorm layers, as here they are trained to the
        # mean and standard deviation of ImageNet and we may lose some signal
        for name, param in model_resnet18.named_parameters():
            if("bn" not in name):
                param.requires_grad = False

        # reduce number of output classes in model
        num_classes = 2
        model_resnet18.fc = nn.Sequential(nn.Linear(model_resnet18.fc.in_features,512),
                                        nn.ReLU(),
                                        nn.Dropout(),
                                        nn.Linear(512, num_classes))

        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        model_resnet18.to(device)
        optimizer = optim.Adam(model_resnet18.parameters(), lr=0.001)
        loss_fn = torch.nn.CrossEntropyLoss()
        epochs = 10
        device = device
        target_class = 1

        for epoch in range(epochs):
            training_loss = 0.0
            valid_loss = 0.0
            model_resnet18.train()
            for batch in train_loader:
                optimizer.zero_grad()
                inputs, targets = batch
                targets = targets.type(torch.LongTensor)
                inputs = inputs.to(device)
                targets = targets.to(device)
                output = model_resnet18(inputs)
                loss = loss_fn(output, targets)

                loss.backward()
                optimizer.step()
                training_loss += loss.data.item() * inputs.size(0)
            training_loss /= len(train_loader.dataset)

            model_resnet18.eval()
            all_targets = []
            all_predictions = []

            for batch in validation_loader:
                inputs, targets = batch
                targets = targets.type(torch.LongTensor)
                inputs = inputs.to(device)
                output = model_resnet18(inputs)
                targets = targets.to(device)
                loss = loss_fn(output,targets)

                valid_loss += loss.data.item() * inputs.size(0)
                predictions = torch.max(F.softmax(output, dim=1), dim=1)[1]
                correct = torch.eq(predictions, targets).view(-1)
                all_targets.extend(targets.cpu().numpy())
                all_predictions.extend(predictions.cpu().numpy())

            valid_loss /= len(validation_loader.dataset)

            # Debug statements
            print(f"Targets distribution: {dict(zip(*np.unique(all_targets, return_counts=True)))}")
            print(f"Predictions distribution: {dict(zip(*np.unique(all_predictions, return_counts=True)))}")

            print(
                f'Epoch: {epoch}, train_loss: {round(training_loss, 2)}'
                f', val_metrics: {BigFlareFinder.get_model_performance_metrics(all_targets, all_predictions)}')

        # TODO: train on the val data (which was excluded from training earlier)
        # - augment the val data
        # - get a new val loader
        # - train on the val loader

        self.pytorch_model = model_resnet18

    def predict(self, image_paths):
        self.pytorch_model.eval()
        pred_image_labels = [None]*len(image_paths)




        return pred_image_labels

    @staticmethod
    def get_model_performance_metrics(y_true, y_pred):
        metrics_dict = {
            "accuracy": round(accuracy_score(y_true, y_pred), 2),
            "f1": round(f1_score(y_true, y_pred), 2),
            "precision_class_1": round(precision_score(y_true, y_pred), 2),
            "recall_class_1": round(recall_score(y_true, y_pred), 2),
        }
        return metrics_dict

    @staticmethod
    def preprocess(image_paths, image_labels):
        # make dataset of image_paths and image_labels
        image_dimension = 224
        image_transforms = transforms.Compose([
            transforms.Resize((image_dimension, image_dimension)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
            ])
        dataset = CustomImageDataset(image_paths, image_labels, image_transforms)

        # make dataloader for dataset
        batch_size = 32
        num_workers = 2
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

        return dataloader


# run the model on some data

# get image paths and labels
notbigflare_max_count = 200
bigflare_max_count = 30
image_folder_path = f"{path_to_coderepo}/../../data/sdo_images"
solar_image_path_df = pd.read_csv(f"{path_to_coderepo}/big_flare_labels.csv").dropna()
solar_image_path_df = pd.concat(
    [
        solar_image_path_df[solar_image_path_df["is_big_flare"] == 0][0:notbigflare_max_count],
        solar_image_path_df[solar_image_path_df["is_big_flare"] == 1][0:bigflare_max_count],
    ]
)
image_paths = (image_folder_path + '/' + solar_image_path_df['solar_image_filename']).to_list()
image_labels = solar_image_path_df['is_big_flare'].to_list()

# split data into train-test using sklearn and use stratified sampling
image_paths_train, image_paths_test, image_labels_train, image_labels_test = train_test_split(
    image_paths, image_labels, test_size=0.2, random_state=42, stratify=image_labels
)

# train model
big_flare_finder = BigFlareFinder()
model = big_flare_finder.fit(image_paths_train, image_labels_train)

# make predictions
pred_image_labels = big_flare_finder.predict(image_paths_test)
pred_image_labels


Using cache found in /root/.cache/torch/hub/pytorch_vision_main


Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 17, 1: 20}
Epoch: 0, train_loss: 0.41, val_metrics: {'accuracy': 0.32, 'f1': 0.0, 'precision_class_1': 0.0, 'recall_class_1': 0.0}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 1, train_loss: 0.24, val_metrics: {'accuracy': 0.84, 'f1': 0.0, 'precision_class_1': 0.0, 'recall_class_1': 0.0}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 37}
Epoch: 2, train_loss: 0.19, val_metrics: {'accuracy': 0.86, 'f1': 0.0, 'precision_class_1': 0.0, 'recall_class_1': 0.0}


  _warn_prf(average, modifier, msg_start, len(result))


Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 3, train_loss: 0.24, val_metrics: {'accuracy': 0.89, 'f1': 0.33, 'precision_class_1': 1.0, 'recall_class_1': 0.2}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 37}
Epoch: 4, train_loss: 0.14, val_metrics: {'accuracy': 0.86, 'f1': 0.0, 'precision_class_1': 0.0, 'recall_class_1': 0.0}


  _warn_prf(average, modifier, msg_start, len(result))


Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 5, train_loss: 0.09, val_metrics: {'accuracy': 0.89, 'f1': 0.33, 'precision_class_1': 1.0, 'recall_class_1': 0.2}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 6, train_loss: 0.08, val_metrics: {'accuracy': 0.89, 'f1': 0.33, 'precision_class_1': 1.0, 'recall_class_1': 0.2}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 7, train_loss: 0.05, val_metrics: {'accuracy': 0.89, 'f1': 0.33, 'precision_class_1': 1.0, 'recall_class_1': 0.2}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 35, 1: 2}
Epoch: 8, train_loss: 0.06, val_metrics: {'accuracy': 0.92, 'f1': 0.57, 'precision_class_1': 1.0, 'recall_class_1': 0.4}
Targets distribution: {0: 32, 1: 5}
Predictions distribution: {0: 36, 1: 1}
Epoch: 9, train_loss: 0.06, val_metrics: {'accuracy': 0.89, 'f1': 0.33, 'precision_class_1': 1.0, 'recall_class_1': 0.2}


RuntimeError: The size of tensor a (4) must match the size of tensor b (3) at non-singleton dimension 0

In [None]:
df_test["is_big_flare"].value_counts(dropna=False)

is_big_flare
0.0    6109
NaN    2323
1.0     118
Name: count, dtype: int64

In [None]:
import os, shutil

In [None]:
# Load the Drive helper and mount
from google.colab import drive

# This will prompt for authorization.
drive.mount('/content/drive', force_remount=True)

In [None]:
# After executing the cell above, Drive
# files will be present in "/content/drive/My Drive".
!ls "/content/drive/My Drive"

In [None]:
import pandas as pd

path_to_data = '/content/drive/My Drive/ML_project/data/sdo_images/'
big_flare_labels_filename = 'big_flare_labels.csv'
big_flare_labels_filpepath = os.path.join(path_to_data, big_flare_labels_filename)
solar_image_path_df = pd.read_csv(big_flare_labels_filpepath)
big_flare_paths = solar_image_path_df[solar_image_path_df['is_big_flare'] == 1]['solar_image_filename'].apply(lambda x: os.path.join(path_to_data, x)).tolist()
not_big_flare_paths = solar_image_path_df[solar_image_path_df['is_big_flare'] == 0]['solar_image_filename'].apply(lambda x: os.path.join(path_to_data, x)).tolist()

len(big_flare_paths), len(not_big_flare_paths)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
model_resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)
model_resnet34 = torch.hub.load('pytorch/vision', 'resnet34', pretrained=True)

In [None]:
# Freeze all params except the BatchNorm layers, as here they are trained to the
# mean and standard deviation of ImageNet and we may lose some signal
for name, param in model_resnet18.named_parameters():
    if("bn" not in name):
        param.requires_grad = False

for name, param in model_resnet34.named_parameters():
    if("bn" not in name):
        param.requires_grad = False

In [None]:
# Replace the classifier
num_classes = 2

model_resnet18.fc = nn.Sequential(nn.Linear(model_resnet18.fc.in_features,512),
                                  nn.ReLU(),
                                  nn.Dropout(),
                                  nn.Linear(512, num_classes))

model_resnet34.fc = nn.Sequential(nn.Linear(model_resnet34.fc.in_features,512),
                                  nn.ReLU(),
                                  nn.Dropout(),
                                  nn.Linear(512, num_classes))

In [None]:
from sklearn.metrics import precision_score, recall_score
import numpy as np

def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=5, device="cpu", target_class=1):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item() * inputs.size(0)
        training_loss /= len(train_loader.dataset)

        model.eval()
        num_correct = 0
        num_examples = 0
        all_targets = []
        all_predictions = []

        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = loss_fn(output,targets)
            valid_loss += loss.data.item() * inputs.size(0)

            predictions = torch.max(F.softmax(output, dim=1), dim=1)[1]
            correct = torch.eq(predictions, targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]

            all_targets.extend(targets.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

        valid_loss /= len(val_loader.dataset)

        # Calculate precision and recall for the target class
        precision = precision_score(all_targets, all_predictions, pos_label=target_class, average='binary', zero_division=0)
        recall = recall_score(all_targets, all_predictions, pos_label=target_class, average='binary', zero_division=0)

        # Debug statements
        print(f"Targets distribution: {dict(zip(*np.unique(all_targets, return_counts=True)))}")
        print(f"Predictions distribution: {dict(zip(*np.unique(all_predictions, return_counts=True)))}")

        print('Epoch: {}, Training Loss: {:.4f}, Validation Loss: {:.4f}, Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}'.format(
            epoch, training_loss, valid_loss, num_correct / num_examples, precision, recall))

In [None]:
import random
from torch.utils.data import Dataset, DataLoader, random_split

# Custom dataset class
class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label


paths_class1 = not_big_flare_paths[:70]
paths_class2 = big_flare_paths[:30]
split_fractions = [0.7, 0.1, 0.2]

def get_dataloaders(paths_class1, paths_class2, split_fractions):
  # split the paths list into subsets for class 1 and class 2
  paths_class1_split = [list(subset) for subset in random_split(paths_class1, split_fractions)]
  paths_class2_split = [list(subset) for subset in random_split(paths_class2, split_fractions)]

  # defining inputs to the DataLoader function
  img_dimensions = 224
  img_transforms = transforms.Compose([
      transforms.Resize((img_dimensions, img_dimensions)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225] )
      ])
  batch_size = 32
  num_workers = 2

  # function to get dataloader from class subsets
  def get_dataloader(class1_subset, class2_subset):
    all_subset = class1_subset + class2_subset
    class1_subset_labels = len(class1_subset)*[0]
    class2_subset_labels = len(class2_subset)*[1]
    all_subset_labels = class1_subset_labels + class2_subset_labels
    subset_dataset = CustomImageDataset(all_subset, all_subset_labels, transform=img_transforms)
    subset_dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return subset_dataloader

  data_loaders = [get_dataloader(class1_subset, class2_subset) for class1_subset, class2_subset in zip(paths_class1_split, paths_class2_split)]

  return data_loaders

train_data_loader, validation_data_loader, test_data_loader = get_dataloaders(paths_class1, paths_class2, split_fractions)
train_data_loader, validation_data_loader, test_data_loader

In [None]:
print(f'Num training images: {len(train_data_loader.dataset)}')
print(f'Num validation images: {len(validation_data_loader.dataset)}')
print(f'Num test images: {len(test_data_loader.dataset)}')

## Train and test the models

In [None]:
def test_model(model):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_data_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('correct: {:d}  total: {:d}'.format(correct, total))
    print('accuracy = {:f}'.format(correct / total))

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
model_resnet18.to(device)
optimizer = optim.Adam(model_resnet18.parameters(), lr=0.001)
train(model_resnet18, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, validation_data_loader, epochs=2, device=device)

Targets distribution: {0: 7, 1: 3} that in the validation set, there are 7 instances of class 0 and 3 instances of class 1. This tells you the actual distribution of classes in your validation data.

Predictions distribution: {0: 9, 1: 1} indicates that the model predicted 9 instances as class 0 and 1 instance as class 1.


In [None]:
test_model(model_resnet18)

In [None]:
model_resnet34.to(device)
optimizer = optim.Adam(model_resnet34.parameters(), lr=0.001)
train(model_resnet34, optimizer, torch.nn.CrossEntropyLoss(), train_data_loader, validation_data_loader, epochs=2, device=device)

In [None]:
test_model(model_resnet34)


## Make some predictions


In [None]:
import os
def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx

def make_prediction(model, filename):
    labels, _ = find_classes('/content/drive/My Drive/ML_project/dogs-vs-cats/test')
    img = Image.open(filename)
    img = img_test_transforms(img)
    img = img.unsqueeze(0)
    prediction = model(img.to(device))
    prediction = prediction.argmax()
    print(labels[prediction])

# make_prediction(model_resnet34, '/content/drive/My Drive/ML_project/dogs-vs-cats/test/dogs/dog.1146.jpg')
# make_prediction(model_resnet34, '/content/drive/My Drive/ML_project/dogs-vs-cats/test/cats/cat.1226.jpg')

In [None]:
torch.save(model_resnet18.state_dict(), "./model_resnet18.pth")
torch.save(model_resnet34.state_dict(), "./model_resnet34.pth")


## Load the models from disk and test with an ensemble

In [None]:
# Remember that you must call model.eval() to set dropout and batch normalization layers to
# evaluation mode before running inference. Failing to do this will yield inconsistent inference result

resnet18 = torch.hub.load('pytorch/vision', 'resnet18')
resnet18.fc = nn.Sequential(nn.Linear(resnet18.fc.in_features,512),nn.ReLU(), nn.Dropout(), nn.Linear(512, num_classes))
resnet18.load_state_dict(torch.load('./model_resnet18.pth'))
resnet18.eval()

resnet34 = torch.hub.load('pytorch/vision', 'resnet34')
resnet34.fc = nn.Sequential(nn.Linear(resnet34.fc.in_features,512),nn.ReLU(), nn.Dropout(), nn.Linear(512, num_classes))
resnet34.load_state_dict(torch.load('./model_resnet34.pth'))
resnet34.eval()

print("done")

In [None]:
# Test against the average of each prediction from the two models
models_ensemble = [resnet18.to(device), resnet34.to(device)]
correct = 0
total = 0
with torch.no_grad():
    for data in test_data_loader:
        images, labels = data[0].to(device), data[1].to(device)
        predictions = [i(images).data for i in models_ensemble]
        avg_predictions = torch.mean(torch.stack(predictions), dim=0)
        _, predicted = torch.max(avg_predictions, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('accuracy = {:f}'.format(correct / total))
print('correct: {:d}  total: {:d}'.format(correct, total))

In [None]:
# Assuming your model and data are on the same device (e.g., 'cuda' or 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet34.to(device)

# Example usage
make_prediction(resnet34, '/content/drive/My Drive/ML_project/dogs-vs-cats/test/dogs/dog.1146.jpg')
make_prediction(resnet34, '/content/drive/My Drive/ML_project/dogs-vs-cats/test/cats/cat.1226.jpg')