In [49]:
import xarray as xr
import os
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np
import xarray as xr
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.optim import Adam
from sklearn.model_selection import train_test_split
import copy
import random

### Import 

In [50]:
data_path = 'l3_blended_l4_extracts_gR_201703.nc'
ds = xr.open_dataset(data_path)
num_images = 500
rand_indices = np.random.randint(0, len(ds.gRsst), num_images)
ds = ds.isel(i=rand_indices)

### Labelling Code

In [51]:
# Assuming ds.gRsst is your dataset of images
good_images_indices = []
bad_images_indices = []
ns_images_indices = []

# Variable to keep track of the current image index
current_index = 0

# Create an Output widget to display images
image_output = widgets.Output()

# Function to update the display with the current image
def update_display():
    with image_output:
        # Clear previous output in the Output widget
        clear_output(wait=True)
        plt.figure(figsize=(5, 5))
        plt.imshow(ds.gRsst[current_index])  # Adjust cmap based on your image type
        plt.title(f'Image {current_index+1}')
        plt.show()

# Function to handle the "Good" button click
def good_button_clicked(b):
    global current_index
    good_images_indices.append(current_index)
    current_index += 1
    show_next_image()

# Function to handle the "Bad" button click
def bad_button_clicked(b):
    global current_index
    bad_images_indices.append(current_index)
    current_index += 1
    show_next_image()

def ns_button_clicked(b):
    global current_index
    ns_images_indices.append(current_index)
    current_index += 1
    show_next_image()

# Function to show the next image
def show_next_image():
    global current_index
    if current_index < len(ds.gRsst):
        update_display()
    else:
        with image_output:
            clear_output(wait=True)
            print("No more images.")

# Creating the "Good" and "Bad" buttons
good_button = widgets.Button(description="Good")
bad_button = widgets.Button(description="Bad")
ns_button = widgets.Button(description="Not Sure")

good_button.on_click(good_button_clicked)
bad_button.on_click(bad_button_clicked)
ns_button.on_click(ns_button_clicked)

# Group the buttons
buttons = widgets.HBox([good_button, bad_button, ns_button])

# Display the buttons and the Output widget
display(buttons, image_output)

# Initialize the display with the first image
update_display()

HBox(children=(Button(description='Good', style=ButtonStyle()), Button(description='Bad', style=ButtonStyle())…

Output()

### Extract good & bad data

In [62]:
good = np.zeros(len(ds.gRsst))
good[good_images_indices] = 1
ds['good'] = (('i',), good)
# drop unsure cases
ds_filtered = ds.drop_sel(i=ns_images_indices)

IndexError: index 439 is out of bounds for axis 0 with size 439

In [63]:
ds = xr.Dataset({'gRsst': ds_filtered.gRsst,
                 'good': ds_filtered.good})
ds = ds.expand_dims(dim='channel', axis=1)
ds.to_netcdf('quality_training')

### Load and Separate Data

In [112]:
## run if you already have file
data_path = 'quality_training'
ds = xr.open_dataset(data_path)

### Prepare Data

In [64]:
class MyDataset(Dataset): 
    ## data inserted as numpy arrays in form (i, channel, y, x)
    def __init__(self, input_data, output_data):
        self.input_data = input_data
        self.output_data = output_data
    def __getitem__(self, index): 
        x = self.input_data[index]
        y = self.output_data[index]
        return x, y
    def __len__(self):
        return len(self.input_data)

In [65]:
def img_normalise(arr, clipped_gradients = True, normalise_type = None):
    if clipped_gradients:
        mini = -1
        maxi = 1
    else: 
        mini = arr.min(dim={'x','y','i'}) #min for each image in each channel: shape (i, channel)
        maxi = arr.max(dim={'x','y','i'})
    if normalise_type == 'zero_one':
        # 0-1 normalise data
        # use this if you will input your data as numpy arrays
        arr = (arr-mini)/(maxi-mini) #normalise each channel to 0-255
    elif normalise_type == 'tensor': 
        # normalise to 0-1 then convert to image values. 
        # use if you plan to use torch.transforms
        arr = 255*(arr-mini)/(maxi-mini) #normalise each channel to 0-255
    return arr

In [95]:
ds = xr.Dataset({'gRsst':img_normalise(ds.gRsst),
                 'good': ds.good})

input_data = ds.gRsst.data
labels = ds.good.data
print(np.shape(labels))
print(np.shape(input_data))

train_input, test_input, train_labels, test_labels = train_test_split(
    input_data, labels, test_size=0.2)
BATCH_SIZE = 8

train_dataset = MyDataset(train_input, train_labels)
test_dataset = MyDataset(test_input, test_labels)

train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE)

for x, y in train_loader: 
    print(f'batch shape: input: {x.shape}, output: {y.shape}, type: {type(x)}')
    print(f'number of batches: {len(train_loader)}')
    print(f'train images: {len(train_loader.dataset)}, test images:{len(test_loader.dataset)}')
    break

batch shape: input: torch.Size([8, 1, 48, 48]), output: torch.Size([8, 1]), type: <class 'torch.Tensor'>
number of batches: 44
train images: 351, test images:88


### Define Classifier

In [96]:
class BC(nn.Module):
    def __init__(self, input_channels = 1):
        # define layers
        super(BC, self).__init__()
        self.conv1 = nn.Conv2d(input_channels, 8, kernel_size=9, stride=1, padding='same')
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding='same')
        self.conv3 = nn.Conv2d(16, 32, kernel_size=5, stride=1, padding='same')
        self.global_pool = nn.AvgPool2d(kernel_size = 48)
        self.fc1 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.sigmoid(self.fc1(x))
        return x

tensor = torch.randn(2,1,48,48)
model = BC()
model(tensor)

tensor([[0.5311],
        [0.5310]], grad_fn=<SigmoidBackward0>)

### Setup Model

In [97]:
## remember to re-run this to restart model training
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
loss_fn = nn.BCEWithLogitsLoss()
input_channels = 1
model = BC(input_channels)
model.to(device)
LR = 3E-4
optimiser = Adam(model.parameters(), lr = LR)

cpu


### Train and Assess Model

In [113]:
def train_model(model, train_loader, test_loader, num_epochs, test_interval, loss_fn, optimiser, device, model_path):
    model.to(device)  # Move model to GPU (or CPU if no GPU)
    model.train()     # Set the model to training mode

    best_loss = 1E6
    train_loss_array = []
    train_acc_array = []
    test_loss_array = []
    test_acc_array = []
    for epoch in range(num_epochs): 
        epoch_train_loss = 0
        epoch_train_acc = 0
        epoch_test_loss = 0
        epoch_test_acc = 0 
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimiser.zero_grad()

            outputs = model(inputs)

            loss = loss_fn(outputs, targets)
            epoch_train_loss += loss.item()

            loss.backward()
            optimiser.step()

            preds = (outputs > 0.5).float() 
            num_correct = (preds == targets).sum().item()
            epoch_train_acc += num_correct/targets.size(0)
            
        train_loss_array.append(epoch_train_loss/len(train_loader)) # avg. over epoch
        train_acc_array.append(epoch_train_acc/len(train_loader))
        
        if train_loss_array[epoch] < best_loss: 
            best_model = copy.deepcopy(model)
            best_loss = train_loss_array[epoch]
            torch.save(best_model.state_dict(), model_path)
        
        print(f'epoch: {epoch+1}, train loss: {train_loss_array[epoch]}, train acc: {train_acc_array[epoch]}')
        
        if epoch % test_interval == test_interval - 1: 
            model.eval()
            for i, (inputs, targets) in enumerate(test_loader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                with torch.nograd():
                    loss = loss_fn(outputs, targets)
                    epoch_test_loss += loss
                    preds = (outputs > 0.5).float() 
                    num_correct = (preds == targets).sum().item()
                    epoch_test_acc += num_correct/targets.size(0)
            test_loss_array.append(epoch_test_loss/len(test_loader)) # avg. over epoch
            test_acc_array.append(epoch_test_acc/len(test_loader))
            plt.subplots(1,2)
            ## plot current stats
            train_epoch_array = range(len(train_loss_array))
            test_epoch_array = range(len(test_loss_array))
            ax[0].plot(train_epoch_array, train_loss_array, label ='test')
            ax[0].plot(test_epoch_array, test_loss_array, label='test')
            ax[0].set_ylabel('loss')
            ax[0].set_xlabel('epoch')
            ax[0].legend()
            ax[1].plot(train_epoch_array, train_acc_array)
            ax[1].plot(test_epoch_array, test_acc_array)
            ax[1].set_ylabel('acc')
            ax[1].set_xlabel('epoch')
        return best_model, train_loss_array, train_acc_array, test_loss_array, test_acc_array
    

num_epochs = 500
test_interval = 10 
model_path = 'classifier'
best_model, train_loss_array, train_acc_array, test_loss_array, test_acc_array = train_model(model, train_loader, test_loader, 
                                                                                            num_epochs, test_interval, loss_fn, optimiser, device, model_path)

epoch: 1, train loss: 0.6931471906372867, train acc: 0.46915584415584416
