# This notebook is under construction. 

- In this notebook a rather simple 1-D CNN is trained and tested on H-alpha diagnostics. 
- Later this 1-D CNN may be used as a supplementary model for the ensembled RIS1xRIS2 resp. RIS1xRIS1 model, which have poor performance in distinguishing H-modes from ELMs.


- Functions written here will migrate to `confinement_mode_classifier.py` once tested

In [1]:
import os
from pathlib import Path
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.backends.backend_agg import FigureCanvasAgg
import re
import seaborn as sns
import torch
import torch.nn.functional as F
import pandas as pd
import torchvision
from tqdm.notebook import tqdm
import pytorch_lightning as pl
import confinement_mode_classifier as cmc
from torchvision.io import read_image
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler
from torchmetrics.classification import MulticlassConfusionMatrix, F1Score, MulticlassPrecision, MulticlassRecall, MulticlassPrecisionRecallCurve, MulticlassROC
from torch.optim import lr_scheduler
import torch.nn as nn
import copy
from tempfile import TemporaryDirectory
from torch.utils.tensorboard import SummaryWriter
import time 
from datetime import datetime
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
path = Path(os.getcwd())
data_dir_path = f'{path}/data/LH_alpha'
file_names = os.listdir(data_dir_path)

In [3]:
batch_size = 32

#Time window for the diagnostics
h_alpha_window = 50

#Shots used in training
shot_numbers = [re.search(r'shot_(\d+)', file_name).group(1) for file_name in file_names]
shots_for_testing = ['18130', '16773', '16534', 
                     '19094', '18133', '17837', 
                     '18128', '19915', '19925', 
                     '13182', '20009', '20112'
                     ]

shots_for_validation = ['16769', '19379', '18057', 
                        '18132', '18261', '18267', 
                        '18260', '20143', '20145', 
                        '20146', '20147', '20144', 
                        '20098'
                        ]


shot_df, test_df, val_df, train_df = cmc.load_and_split_dataframes(path, shot_numbers, 
                                                                   shots_for_testing, 
                                                                   shots_for_validation, 
                                                                   use_ELMS=True)

#Test dloader is not balanced -> testing the ability to define ELM as anomalies
test_dataloader = cmc.get_dloader(test_df, path=path, batch_size=batch_size, 
                                    balance_data=False, only_halpha=True, 
                                    second_img_opt=None, shuffle=False,
                                    h_alpha_window = h_alpha_window)

val_dataloader = cmc.get_dloader(val_df, path=path, batch_size=batch_size, 
                                    balance_data=True, only_halpha=True, 
                                    second_img_opt=None, shuffle=False,
                                    h_alpha_window = h_alpha_window)

train_dataloader = cmc.get_dloader(train_df, path=path, batch_size=batch_size, 
                                    balance_data=True, only_halpha=True, 
                                    second_img_opt=None, shuffle=False,
                                    h_alpha_window = h_alpha_window)

Create a model class

In [4]:
class Simple1DCNN(nn.Module):
    def __init__(self, num_classes=3, h_alpha_window=80):
        super(Simple1DCNN, self).__init__()
        # Define the 1D convolutional layers
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1, dilation=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=2)
        self.batch_norm1 = nn.BatchNorm1d(32)
        # Define a fully connected layer for classification
        ### in_features = floor[((input_length + 2*padding - dilation*(kernel_size - 1) - 1) // stride) + 1]
        self.fc = nn.Linear(in_features=32 * (h_alpha_window - 2), out_features=num_classes)

    def forward(self, x):
        # Apply 1D convolutions
        x = x.unsqueeze(1)
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.batch_norm1(x)  #!!! should I use some activation function here?

        # Flatten the tensor for the fully connected layer
        x = x.view(x.size(0), -1)

        # Apply the fully connected layer and return the output
        x = self.fc(x)
        return x

: 

- This `train_model()` function is just a modified function copied from confinement_mode_classifier.py
- The main difference resides in how the input data are parsed to the model (batch of all the diagnostics vs batch of just RIS imgs)
- Will have to generalize all the models in order to use a single function for all

In [9]:
def train_model(model, criterion, optimizer, scheduler:lr_scheduler, dataloaders: dict,
                 writer: SummaryWriter, dataset_sizes={'train':1, 'val':1}, num_epochs=25,
                 chkpt_path=os.getcwd()):
    since = time.time()


    torch.save(model.state_dict(), chkpt_path)
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            num_of_samples = 0
            running_batch = 0
            # Iterate over data.
            #TODO: eliminate the need in that dummy iterative for tensorboard part
            for batch in tqdm(dataloaders[phase]):
                
                inputs = batch['h_alpha'].to(device).float() # #TODO: is it smart to convert double to float here? 
                labels = batch['label'].to(device)
                
                running_batch += 1
                
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs) #2D tensor with shape Batchsize*len(modes)
                    #TODO: inputs.type. 
                    _, preds = torch.max(outputs, 1) #preds = 1D array of indicies of maximum values in row. ([2,1,2,1,2]) - third feature is largest in first sample, second in second...
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                num_of_samples += inputs.size(0)
                running_corrects += torch.sum(preds == labels.data) #How many correct answers
                
                
                #tensorboard part
                
                if running_batch % int(len(dataloaders[phase])/10)==int(len(dataloaders[phase])/10)-1: 
                    # ...log the running loss
                    
                    #Training/validation loss
                    writer.add_scalar(f'{phase}ing loss',
                                    loss,
                                    epoch * len(dataloaders[phase]) + running_batch)
                    
                    #F1 metric
                    writer.add_scalar(f'{phase}ing F1 metric',
                                    F1Score(task="multiclass", num_classes=3).to(device)(preds, labels),
                                    epoch * len(dataloaders[phase]) + running_batch)
                    
                    #Precision recall
                    writer.add_scalar(f'{phase}ing macro Precision', 
                                        MulticlassPrecision(num_classes=3).to(device)(preds, labels),
                                        epoch * len(dataloaders[phase]) + running_batch)
                    
                    writer.add_scalar(f'{phase}ing macro Recall', 
                                        MulticlassRecall(num_classes=3).to(device)(preds, labels),
                                        epoch * len(dataloaders[phase]) + running_batch)
                    
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                writer.add_scalar(f'best_accuracy for epoch',
                                    epoch_acc,
                                    epoch)
                writer.close()
                best_acc = epoch_acc
                torch.save(model.state_dict(), chkpt_path)


        time_elapsed = time.time() - since


        # load best model weights
        model.load_state_dict(torch.load(chkpt_path))
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:4f}')
    return model

In [10]:
dataloaders = {'train':train_dataloader, 'val':val_dataloader}
dataset_sizes = {x: len(dataloaders[x].dataset) for x in ['train', 'val']}

timestamp =  datetime.fromtimestamp(time.time()).strftime("%y-%m-%d, %H-%M-%S ") + input('add comment: ')
# create grid of images
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter(f'runs/{timestamp}')
model_path = Path(f'{path}/runs/{timestamp}/model.pt')

In [11]:
untrained_cnn = Simple1DCNN(h_alpha_window=h_alpha_window)
untrained_cnn = untrained_cnn.to(device)

In [40]:
sample_input = next(iter(train_dataloader))['h_alpha'].to(device).float()
writer.add_graph(untrained_cnn, sample_input)

In [None]:

criterion = nn.CrossEntropyLoss()
# Observe that all parameters are being optimized
optimizer = torch.optim.Adam(untrained_cnn.parameters(), lr=1e-3) #pouzit adam

exp_lr_scheduler = lr_scheduler.OneCycleLR(optimizer, max_lr=1e-2, total_steps=50) #!!!

num_epochs = 16
trained_cnn = train_model(untrained_cnn, criterion, optimizer, exp_lr_scheduler, 
                       dataloaders, writer, dataset_sizes, num_epochs=num_epochs, 
                       chkpt_path = model_path.with_name(f'{model_path.stem}_chkpt{model_path.suffix}'))

torch.save(trained_cnn.state_dict(), model_path)

- Again, `test_model()` is modified function from `cmc` 
- Again main difference is how the batch is processed

In [42]:
def test_model(run_path, model: torchvision.models.resnet.ResNet, test_dataloader: DataLoader,
                max_batch: int = 0, return_metrics: bool = True, comment: str =''):
    '''
    Takes model and dataloader and returns figure with confusion matrix, 
    dataframe with predictions, F1 metric value, precision, recall and accuracy

    Args:
        model: ResNet model
        test_dataloader: DataLoader used for testing
        max_batch: maximum number of bathces to use for testing. Set = 0 to use all batches in DataLoader
        return_metrics: if True returns confusion matrix, F1, precision, recall and accuracy 
    
    Returns: 
        preds: pd.DataFrame() pd.DataFrame with columns of predicted class, true class, frame time and confidence of the prediction
        precision: MulticlassPrecision(num_classes=3)
        recall: MulticlassRecall(num_classes=3)
        accuracy: (TP+TN)/(TP+TN+FN+FP)
        fig_confusion_matrix: MulticlassConfusionMatrix(num_classes=3)
    '''
    y_df = torch.tensor([])
    y_hat_df = torch.tensor([])
    preds = pd.DataFrame(columns=['shot', 'prediction', 'label', 'time', 'confidence', 'L_logit', 'H_logit', 'ELM_logit'])
    pattern = re.compile(r'RIS1_(\d+)_t=')
    batch_index = 0 #iterator
    for batch in tqdm(test_dataloader, desc='Processing batches'):
        batch_index +=1
        outputs, y_hat, confidence = cmc.images_to_probs(model, batch['h_alpha'].to(device).float())
        y_hat = torch.tensor(y_hat)
        y_df = torch.cat((y_df.int(), batch['label']), dim=0)
        y_hat_df = torch.cat((y_hat_df, y_hat), dim=0)
        shot_numbers = [int(pattern.search(path).group(1)) for path in batch['path']]

        pred = pd.DataFrame({'shot': shot_numbers, 'prediction': y_hat.data, 
                            'label': batch['label'].data, 'time':batch['time'], 
                            'confidence': confidence,'L_logit': outputs[:,0].cpu(), 
                            'H_logit': outputs[:,1].cpu(), 'ELM_logit': outputs[:,2].cpu()})

        preds = pd.concat([preds, pred],axis=0, ignore_index=True)

        if max_batch!=0 and batch_index>max_batch:
            break

    if return_metrics:
        softmax_out = torch.nn.functional.softmax(torch.tensor(preds[['L_logit','H_logit','ELM_logit']].values), dim=1)
        #Confusion matrix
        confusion_matrix_metric = MulticlassConfusionMatrix(num_classes=3)
        confusion_matrix_metric.update(y_hat_df, y_df)
        conf_matrix_fig, conf_matrix_ax  = confusion_matrix_metric.plot()
        #F1
        f1 = F1Score(task="multiclass", num_classes=3)(y_hat_df, y_df)

        #Precision
        precision = MulticlassPrecision(num_classes=3)(y_hat_df, y_df)
        recall = MulticlassRecall(num_classes=3)(y_hat_df, y_df)
        #precision(logits_df, y_df.int())
         #Precision_recall curve
        pr_curve = MulticlassPrecisionRecallCurve(num_classes=3, thresholds=64)
        pr_curve.update(softmax_out, y_df)
        pr_curve_fig, pr_curve_ax = pr_curve.plot(score=True)
        #ROC metric
        mcroc = MulticlassROC(num_classes=3, thresholds=64)
        mcroc.update(torch.tensor(preds[['L_logit', 'H_logit', 'ELM_logit']].values.astype(float)), y_df)
        roc_fig, roc_ax = mcroc.plot(score=True)
        #Accuracy
        accuracy = len(preds[preds['prediction']==preds['label']])/len(preds)

        textstr = '\n'.join((
            f'Whole test dset',
            r'threshhold = 0.5:',
            r'f1=%.2f' % (f1.item(), ),
            r'precision=%.2f' % (precision.item(), ),
            r'recall=%.2f' % (recall.item(), ),
            r'accuracy=%.2f' % (accuracy, )))
        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        
        conf_matrix_ax.set_title(f'confusion matrix for whole test dset')
        pr_curve_ax.set_title(f'pr_curve for whole test dset')
        pr_curve_ax.set_xlabel('Precision')
        pr_curve_ax.set_ylabel('Recall')
        roc_ax.text(0.05, 0.3, textstr, fontsize=14, verticalalignment='bottom', bbox=props)
        roc_ax.set_xlabel('FP Rate')
        roc_ax.set_ylabel('TP Rate')


        # Open the saved images using Pillow
        roc_img = cmc.matplotlib_figure_to_pil_image(roc_fig)
        conf_matrix_img = cmc.matplotlib_figure_to_pil_image(conf_matrix_fig)
        pr_curve_img = cmc.matplotlib_figure_to_pil_image(pr_curve_fig)
        combined_image = Image.new('RGB', (conf_matrix_img.width + pr_curve_img.width + roc_img.width,\
                                            conf_matrix_img.height))

        # Paste the saved images into the combined image
        combined_image.paste(conf_matrix_img, (0, 0))
        combined_image.paste(roc_img, (conf_matrix_img.width, 0))
        combined_image.paste(pr_curve_img, (roc_img.width+conf_matrix_img.width, 0))
        
        # Save the combined image
        combined_image.save(f'{run_path}/metrics_for_whole_test_dset_{comment}.png')

        return preds, (conf_matrix_fig, conf_matrix_ax), f1, precision, recall, accuracy, (pr_curve_fig, pr_curve_ax), (roc_fig, roc_ax)
    else: 
        return preds

In [52]:
metrics = test_model(f'{path}/runs/{timestamp}', trained_cnn, test_dataloader, comment ='3 classes')

Processing batches:   0%|          | 0/942 [00:00<?, ?it/s]

In [76]:
def per_shot_test(path, shots: list, results_df: pd.DataFrame):
    '''
    Takes model's results dataframe from confinement_mode_classifier.test_model() and shot numbers.
    Returns metrics of model for each shot separately

    Args: 
        shots: list with numbers of shot to be tested on.
        model: ResNet model
        results_df: pd.DataFrame from confinement_mode_classifier.test_model().
        time_confidence_img: Image with model confidence on separate shot
        roc_img: Image with ROC 
        conf_matrix_img: Image with confusion matrix
        combined_image: Combined image with three previous returns
    Returns:
        path: Path where images are saved
    '''

    for shot in tqdm(shots):
        pred_for_shot = results_df[results_df['shot']==shot]
        softmax_out = torch.nn.functional.softmax(torch.tensor(pred_for_shot[['L_logit','H_logit','ELM_logit']].values), dim=1)

        preds_tensor = torch.tensor(pred_for_shot['prediction'].values.astype(float))
        labels_tensor = torch.tensor(pred_for_shot['label'].values.astype(int))
        
        #Confusion matrix
        confusion_matrix_metric = MulticlassConfusionMatrix(num_classes=3)
        confusion_matrix_metric.update(preds_tensor, labels_tensor)
        conf_matrix_fig, conf_matrix_ax = confusion_matrix_metric.plot()
        

        #f1 score
        f1 = F1Score(task="multiclass", num_classes=3)(preds_tensor, labels_tensor)

        #Precision
        precision = MulticlassPrecision(num_classes=3)(preds_tensor, labels_tensor)

        #recall
        recall = MulticlassRecall(num_classes=3)(preds_tensor, labels_tensor)

        #accuracy
        accuracy = len(pred_for_shot[pred_for_shot['prediction']==pred_for_shot['label']])/len(pred_for_shot)

        textstr = '\n'.join((
            f'shot {shot}',
            r'threshhold = 0.5:',
            r'f1=%.2f' % (f1.item(), ),
            r'precision=%.2f' % (precision.item(), ),
            r'recall=%.2f' % (recall.item(), ),
            r'accuracy=%.2f' % (accuracy, )))
        # these are matplotlib.patch.Patch properties
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

        conf_time_fig, conf_time_ax = plt.subplots(figsize=(10,6))
        conf_time_ax.plot(pred_for_shot['time'],softmax_out[:,1], label='H-mode Confidence')
        conf_time_ax.plot(pred_for_shot['time'],-softmax_out[:,2], label='ELM Confidence')

        conf_time_ax.scatter(pred_for_shot[pred_for_shot['label']==1]['time'], 
                          len(pred_for_shot[pred_for_shot['label']==1])*[1], 
                          s=2, alpha=1, label='H-mode Truth', color='maroon')
        
        conf_time_ax.scatter(pred_for_shot[pred_for_shot['label']==2]['time'], 
                          len(pred_for_shot[pred_for_shot['label']==2])*[-1], 
                          s=2, alpha=1, label='ELM Truth', color='royalblue')
    
        conf_time_ax.text(0.05, 0.3, textstr, fontsize=14, verticalalignment='bottom', bbox=props)
        conf_time_ax.set_xlabel('t [ms]')
        conf_time_ax.set_ylabel('Confidence')

        plt.title(f'shot {shot}')
        conf_time_ax.legend()

        conf_matrix_ax.set_title(f'confusion matrix for shot {shot}')
        conf_matrix_fig.set_figheight(conf_time_fig.get_size_inches()[1])

        # Open the saved images using Pillow
        time_confidence_img = matplotlib_figure_to_pil_image(conf_time_fig)
        conf_matrix_img = matplotlib_figure_to_pil_image(conf_matrix_fig)

        combined_image = Image.new('RGB', (time_confidence_img.width + conf_matrix_img.width,
                                            time_confidence_img.height))

        # Paste the saved images into the combined image
        combined_image.paste(time_confidence_img, (0, 0))
        combined_image.paste(conf_matrix_img, (time_confidence_img.width, 0))

        # Save the combined image
        combined_image.save(f'{path}/metrics_for_shot_{shot}.png')

    return f'{path}/data'


def matplotlib_figure_to_pil_image(fig):
    """
    Convert a Matplotlib figure to a PIL Image.

    Parameters:
    - fig (matplotlib.figure.Figure): The Matplotlib figure to be converted.

    Returns:
    - PIL.Image.Image: The corresponding PIL Image.

    Example:
    >>> fig, ax = plt.subplots()
    >>> ax.plot([1, 2, 3, 4], [10, 5, 20, 15])
    >>> pil_image = matplotlib_figure_to_pil_image(fig)
    >>> pil_image.save("output_image.png")
    >>> pil_image.show()
    """
    # Create a FigureCanvasAgg to render the figure
    canvas = FigureCanvasAgg(fig)

    # Render the figure to a bitmap
    canvas.draw()

    # Get the RGB buffer from the bitmap
    buf = canvas.buffer_rgba()

    # Convert the buffer to a PIL Image
    image = Image.frombuffer("RGBA", canvas.get_width_height(), buf, "raw", "RGBA", 0, 1)

    return image


In [77]:
per_shot_test(f'{path}/runs/{timestamp}', [int(shot) for shot in shots_for_testing], metrics[0])

  0%|          | 0/12 [00:00<?, ?it/s]

'/compass/Shared/Users/bogdanov/vyzkumny_ukol/runs/24-02-21, 12-33-56 h_alpha, 1dCNN, lr=1e-3/data'