In [None]:
import torch
import numpy as np
import sys
from baseline_train import *


def main(k_shot):
    """Evaluate a nearest-centroid classifier on few-shot MIMIC-CXR episodes.

    For every episode a fresh ``CosineSimilarityNet`` is initialised from the
    pretrained baseline checkpoint (with its ``linear`` head removed). The
    support-set features are averaged per class, and those class centroids
    are written into the model's ``cos_sim.weight``. The model is then scored
    on the episode's query set with ``test``.

    Args:
        k_shot: Number of support images per class in each episode.

    Side effects:
        Prints one metrics line per episode. Writes every episode's metrics
        to ``../../results/{k_shot}shot/{k_shot}shot_nc.csv``, creating the
        folder if needed.
    """
    # Set seed for reproducibility
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    n_way = 3
    k_query = 16
    num_episodes = 200
    num_workers = 12
    bs = 4
    root = '../../../../scratch/rl80/mimic-cxr-jpg-2.0.0.physionet.org/files'
    path_splits = '../splits/splits.csv'  # Location of preprocessed splits
    path_results = f'../../results/{k_shot}shot'  # Folder to save the CSV results
    path_pretrained = '../results/basic/basic_39.pth'

    # Use the first GPU if one exists, otherwise the CPU. set_device raises
    # without CUDA, so it must be guarded for the CPU fallback to work.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()

    # Read the checkpoint once instead of once per episode.
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    # The baseline's linear head is dropped here; the nearest-centroid weights
    # are filled in per episode below. load_state_dict copies these tensors
    # into the model, so reusing this dict across episodes is safe.
    pretrained_dict = torch.load(path_pretrained, map_location=device)
    del pretrained_dict['linear.weight']
    del pretrained_dict['linear.bias']

    # Load in data
    dataset = MimicCxrJpgEpisodes(root, path_splits, n_way, k_shot, k_query, num_episodes, 'novel')
    loader = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=num_workers)

    columns = ['Validation Loss', 'Accuracy', 'Macro Accuracy',
               'Macro-F1 Score'] + [f'{x} F1' for x in range(n_way)]
    # One list of metrics per episode. Built into a DataFrame once at the end
    # (DataFrame.append is quadratic and was removed in pandas 2.0).
    rows = []

    # Iterate through batched episodes. One episode is one experiment
    for support_imgs, support_labels, query_imgs, query_labels in loader:
        batch_support_x, batch_support_y = support_imgs.to(device), support_labels.to(device)
        batch_query_x, batch_query_y = query_imgs.to(device), query_labels.to(device)

        # Dim 0 of each batch tensor indexes the episodes in this batch
        for episode in range(batch_support_x.size(0)):
            # Fresh model per episode so no state leaks between experiments
            model = CosineSimilarityNet(n_way).to(device)
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            support_x, support_y = batch_support_x[episode], batch_support_y[episode]
            query_x, query_y = batch_query_x[episode], batch_query_y[episode]

            model.eval()
            with torch.no_grad():
                _, features = model(support_x, extract_features=True)

                # Row c = mean support feature of class c (nearest centroid).
                # Assumes one integer label per support image in [0, n_way)
                # and features of shape (num_support, feat_dim) -- confirm.
                labels = support_y.view(-1)
                fc_weight = torch.stack([features[labels == c].mean(0) for c in range(n_way)])

                nc_dict = model.state_dict()
                nc_dict['cos_sim.weight'] = fc_weight
                model.load_state_dict(nc_dict)

            # Testing
            val_loss, acc, m_acc, macro_f1, class_f1 = test(query_x, query_y, model, criterion, device, n_way)

            # Print the results per experiment
            print(f'[v_loss: {val_loss:.5f} val_acc: {acc:.5f} val_m_acc: {m_acc:.5f} f1: {macro_f1:.5f}')

            rows.append([val_loss, acc, m_acc, macro_f1] + list(class_f1))

    # Create results folder if it does not exist (no check-then-create race)
    os.makedirs(path_results, exist_ok=True)

    df_results = pd.DataFrame(rows, columns=columns)
    df_results.to_csv(os.path.join(path_results, f'{k_shot}shot_nc.csv'), index=False)  # Export results to a CSV file


if __name__ == '__main__':
    # Run the nearest-centroid evaluation for each few-shot setting in turn.
    # For a single setting from the command line, use main(int(sys.argv[1])).
    for shots in (1, 3, 5, 10, 20):
        main(shots)


[v_loss: 1.09525 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.00000
[v_loss: 1.10243 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.33184
[v_loss: 1.09105 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.33434
[v_loss: 1.09488 val_acc: 0.43750 val_m_acc: 0.43750 f1: 0.43459
[v_loss: 1.09399 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.28654
[v_loss: 1.09397 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35659
[v_loss: 1.10067 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.32194
[v_loss: 1.10131 val_acc: 0.25000 val_m_acc: 0.25000 f1: 0.00000
[v_loss: 1.09945 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.27666
[v_loss: 1.08890 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.36150
[v_loss: 1.10344 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.37267
[v_loss: 1.09031 val_acc: 0.45833 val_m_acc: 0.45833 f1: 0.43407
[v_loss: 1.09506 val_acc: 0.45833 val_m_acc: 0.45833 f1: 0.42097
[v_loss: 1.09878 val_acc: 0.29167 val_m_acc: 0.29167 f1: 0.28472
[v_loss: 1.09857 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.00000
[v_loss: 1.09447 val_acc:

[v_loss: 1.09583 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.00000
[v_loss: 1.09264 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.27531
[v_loss: 1.09435 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.29860
[v_loss: 1.10668 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.29965
[v_loss: 1.09520 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.30935
[v_loss: 1.09867 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.32961
[v_loss: 1.09575 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.33025
[v_loss: 1.10318 val_acc: 0.25000 val_m_acc: 0.25000 f1: 0.22621
[v_loss: 1.09763 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.38038
[v_loss: 1.09548 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.30476
[v_loss: 1.10276 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.26773
[v_loss: 1.09518 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.00000
[v_loss: 1.10298 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.23175
[v_loss: 1.09702 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.32124
[v_loss: 1.10210 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.28864
[v_loss: 1.10766 val_acc:

[v_loss: 1.09549 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.30028
[v_loss: 1.10276 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.32050
[v_loss: 1.09771 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.37592
[v_loss: 1.09484 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.38168
[v_loss: 1.10175 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35570
[v_loss: 1.10424 val_acc: 0.20833 val_m_acc: 0.20833 f1: 0.19776
[v_loss: 1.09830 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.36485
[v_loss: 1.09517 val_acc: 0.47917 val_m_acc: 0.47917 f1: 0.00000
[v_loss: 1.10063 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.34026
[v_loss: 1.10150 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.00000
[v_loss: 1.09718 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.40720
[v_loss: 1.09709 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.38966
[v_loss: 1.09701 val_acc: 0.29167 val_m_acc: 0.29167 f1: 0.29057
[v_loss: 1.09719 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.34365
[v_loss: 1.10107 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.36483
[v_loss: 1.09792 val_acc:

[v_loss: 1.10323 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.25840
[v_loss: 1.09727 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.34803
[v_loss: 1.10417 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.25008
[v_loss: 1.09889 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.39820
[v_loss: 1.09866 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.29955
[v_loss: 1.10300 val_acc: 0.27083 val_m_acc: 0.27083 f1: 0.00000
[v_loss: 1.09453 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35610
[v_loss: 1.10008 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.30267
[v_loss: 1.09187 val_acc: 0.43750 val_m_acc: 0.43750 f1: 0.42197
[v_loss: 1.09575 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.34554
[v_loss: 1.09590 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.34466
[v_loss: 1.09562 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.40476
[v_loss: 1.09935 val_acc: 0.43750 val_m_acc: 0.43750 f1: 0.43533
[v_loss: 1.08885 val_acc: 0.47917 val_m_acc: 0.47917 f1: 0.47781
[v_loss: 1.09782 val_acc: 0.29167 val_m_acc: 0.29167 f1: 0.24622
[v_loss: 1.09745 val_acc:

[v_loss: 1.09705 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.35738
[v_loss: 1.09587 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.30556
[v_loss: 1.09595 val_acc: 0.43750 val_m_acc: 0.43750 f1: 0.44413
[v_loss: 1.09820 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35043
[v_loss: 1.10068 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.27697
[v_loss: 1.09414 val_acc: 0.45833 val_m_acc: 0.45833 f1: 0.43507
[v_loss: 1.09705 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35455
[v_loss: 1.09486 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.00000
[v_loss: 1.09311 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.40617
[v_loss: 1.09519 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.31729
[v_loss: 1.09026 val_acc: 0.45833 val_m_acc: 0.45833 f1: 0.44388
[v_loss: 1.10017 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.31435
[v_loss: 1.09814 val_acc: 0.35417 val_m_acc: 0.35417 f1: 0.35610
[v_loss: 1.09787 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.37664
[v_loss: 1.10044 val_acc: 0.25000 val_m_acc: 0.25000 f1: 0.24735
[v_loss: 1.09877 val_acc:

[v_loss: 1.09216 val_acc: 0.50000 val_m_acc: 0.50000 f1: 0.49974
[v_loss: 1.09583 val_acc: 0.39583 val_m_acc: 0.39583 f1: 0.39436
[v_loss: 1.09503 val_acc: 0.45833 val_m_acc: 0.45833 f1: 0.45238
[v_loss: 1.10039 val_acc: 0.22917 val_m_acc: 0.22917 f1: 0.23310
[v_loss: 1.10507 val_acc: 0.18750 val_m_acc: 0.18750 f1: 0.17965
[v_loss: 1.09301 val_acc: 0.47917 val_m_acc: 0.47917 f1: 0.47588
[v_loss: 1.09834 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.40073
[v_loss: 1.09465 val_acc: 0.43750 val_m_acc: 0.43750 f1: 0.41429
[v_loss: 1.09945 val_acc: 0.33333 val_m_acc: 0.33333 f1: 0.32479
[v_loss: 1.09510 val_acc: 0.41667 val_m_acc: 0.41667 f1: 0.40801
[v_loss: 1.09610 val_acc: 0.37500 val_m_acc: 0.37500 f1: 0.36883
[v_loss: 1.09077 val_acc: 0.54167 val_m_acc: 0.54167 f1: 0.54297
[v_loss: 1.09768 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.30620
[v_loss: 1.09824 val_acc: 0.31250 val_m_acc: 0.31250 f1: 0.31208
[v_loss: 1.09009 val_acc: 0.56250 val_m_acc: 0.56250 f1: 0.55965
[v_loss: 1.09692 val_acc:

In [None]:
import torch
import numpy as np
import sys
from baseline_train import *


def main(k_shot, verbose=False):
    """Evaluate nearest-centroid weights on top of the cosine-trained backbone.

    For every few-shot MIMIC-CXR episode a fresh ``CosineSimilarityNet`` is
    initialised from the cosine baseline checkpoint (with its
    ``cos_sim.weight`` removed). The support-set features are averaged per
    class, and those class centroids become the new ``cos_sim.weight``. The
    model is then scored on the episode's query set with ``test``.

    Args:
        k_shot: Number of support images per class in each episode.
        verbose: If True, print one metrics line per episode.

    Side effects:
        Writes every episode's metrics to
        ``../../results/{k_shot}shot/{k_shot}shot_nc_cs.csv``, creating the
        folder if needed. Prints ``saved`` once the file is written.
    """
    # Set seed for reproducibility
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    n_way = 3
    k_query = 16
    num_episodes = 200
    num_workers = 12
    bs = 4
    root = '../../../../scratch/rl80/mimic-cxr-jpg-2.0.0.physionet.org/files'
    path_splits = '../splits/splits.csv'  # Location of preprocessed splits
    path_results = f'../../results/{k_shot}shot'  # Folder to save the CSV results
    path_pretrained = '../results/basic_cosine/basic_cosine_49.pth'

    # Use the first GPU if one exists, otherwise the CPU. set_device raises
    # without CUDA, so it must be guarded for the CPU fallback to work.
    if torch.cuda.is_available():
        torch.cuda.set_device(0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = nn.CrossEntropyLoss()

    # Read the checkpoint once instead of once per episode.
    # map_location lets a GPU-saved checkpoint load on a CPU-only host.
    # The trained last layer is dropped; it is replaced per episode below.
    # load_state_dict copies these tensors, so reusing the dict is safe.
    pretrained_dict = torch.load(path_pretrained, map_location=device)
    del pretrained_dict['cos_sim.weight']

    # Load in data
    dataset = MimicCxrJpgEpisodes(root, path_splits, n_way, k_shot, k_query, num_episodes, 'novel')
    loader = DataLoader(dataset, batch_size=bs, shuffle=True, num_workers=num_workers)

    columns = ['Validation Loss', 'Accuracy', 'Macro Accuracy',
               'Macro-F1 Score'] + [f'{x} F1' for x in range(n_way)]
    # One list of metrics per episode. Built into a DataFrame once at the end
    # (DataFrame.append is quadratic and was removed in pandas 2.0).
    rows = []

    # Iterate through batched episodes. One episode is one experiment
    for support_imgs, support_labels, query_imgs, query_labels in loader:
        batch_support_x, batch_support_y = support_imgs.to(device), support_labels.to(device)
        batch_query_x, batch_query_y = query_imgs.to(device), query_labels.to(device)

        # Dim 0 of each batch tensor indexes the episodes in this batch
        for episode in range(batch_support_x.size(0)):
            # Fresh model per episode so no state leaks between experiments
            model = CosineSimilarityNet(n_way).to(device)
            model_dict = model.state_dict()
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            support_x, support_y = batch_support_x[episode], batch_support_y[episode]
            query_x, query_y = batch_query_x[episode], batch_query_y[episode]

            model.eval()
            with torch.no_grad():
                _, features = model(support_x, extract_features=True)

                # Row c = mean support feature of class c (nearest centroid).
                # Assumes one integer label per support image in [0, n_way)
                # and features of shape (num_support, feat_dim) -- confirm.
                labels = support_y.view(-1)
                fc_weight = torch.stack([features[labels == c].mean(0) for c in range(n_way)])

                nc_dict = model.state_dict()
                nc_dict['cos_sim.weight'] = fc_weight
                model.load_state_dict(nc_dict)

            # Testing
            val_loss, acc, m_acc, macro_f1, class_f1 = test(query_x, query_y, model, criterion, device, n_way)

            if verbose:
                print(f'[v_loss: {val_loss:.5f} val_acc: {acc:.5f} val_m_acc: {m_acc:.5f} f1: {macro_f1:.5f}')

            rows.append([val_loss, acc, m_acc, macro_f1] + list(class_f1))

    # Create results folder if it does not exist (no check-then-create race)
    os.makedirs(path_results, exist_ok=True)

    df_results = pd.DataFrame(rows, columns=columns)
    df_results.to_csv(os.path.join(path_results, f'{k_shot}shot_nc_cs.csv'), index=False)  # Export results to a CSV file
    # Report only after the file has actually been written
    print('saved')


if __name__ == '__main__':
    # Run the cosine-backbone nearest-centroid evaluation for each setting.
    # For a single setting from the command line, use main(int(sys.argv[1])).
    for shots in (1, 3, 5, 10, 20):
        main(shots)