# Visualize Predictions

## Parameters

In [None]:
import torch

# Seed PyTorch's global random number generator with a fixed value.
SEED = 0
torch.manual_seed(SEED)

In [38]:
# Path of the experiment's YAML configuration file (parsed in the "Load Data" section).
config = "/home/Talen/foragefish_classifier/configs/exp_resnet18.yaml"
# Name of the dataset split this notebook is meant to visualize.
split = "test"

## Load Data

In [39]:
import sys

# Add the project checkout to the module search path so that project modules
# (e.g. `train`, imported in the next cell) can be found.
PROJECT_ROOT = '/home/Talen/foragefish_classifier'
sys.path.append(PROJECT_ROOT)


In [None]:
import yaml
from train import create_dataloader, load_model       # NOTE: since we're using these functions across files, it could make sense to put them in e.g. a "util.py" script.

# load config; the `with` block closes the file handle as soon as parsing is done
print(f'Using config "{config}"')
with open(config, 'r') as config_file:
    cfg = yaml.safe_load(config_file)


# setup entities: build the dataloader for the split selected in the "Parameters" cell
# (previously 'test' was hard-coded here, so changing `split` had no effect)
dl_test = create_dataloader(cfg, split=split)

# load model; load_model returns the model together with a second value that is
# kept as `epoch` (presumably the epoch of the restored checkpoint - confirm in train.py)
model, epoch = load_model(cfg)

## Visualize

This is up to you to figure out now. :)

In [None]:
import torch
from tqdm import trange

# Use the GPU when one is available, otherwise fall back to the CPU instead of
# failing in model.to().
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # puts model weights on to the selected device
model.eval()  # changes model to eval / inference mode

progressBar = trange(len(dl_test))
pred_all = []    # raw model output, one array per image
argmax_all = []  # predicted class index, one int per image
gt_all = []      # ground-truth label, one entry per image (needed by the DataFrame and plot cells)
img_list = []    # image tensors, moved to the CPU so the whole dataset does not stay in GPU memory

# Inference only: no_grad() stops autograd from recording the forward pass,
# so no gradient graph is kept alive for each batch.
with torch.no_grad():
    # see the last line of file "dataset.py" where we return the image tensor (data) and label
    for data, labels in dl_test:
        # only the images are needed on the device; the labels are just recorded
        data = data.to(device)

        # forward pass
        prediction = model(data)

        # Index of the highest score along dim 1 for every image in the batch
        # (assumes prediction is (batch, num_classes) - TODO confirm).
        argmax = prediction.argmax(dim=1)

        # Collect everything per image and in the same order, so that entry i of
        # each list refers to the same image.
        argmax_all.extend(argmax.cpu().tolist())
        gt_all.extend(labels.cpu().tolist())
        pred_all.extend(prediction.cpu().numpy())
        img_list.extend(data.cpu())

        progressBar.update(1)

progressBar.close()
    


# Step 1 - visualize predictions + ground truth in matplotlib
# Step 2 - look up weights + biases, and how to set them up in the model to log during training
# Step 3 - set up experiments so that starting a new training run generates an experiment folder
#          with the right name, and copy the config file into each experiment folder


In [None]:
# Echo the collected image tensors (a bare name as the last expression of a notebook cell is displayed).
# NOTE(review): this prints every raw tensor; `len(img_list)` would be a lighter sanity check.
img_list


# Now we use argmax() over the prediction pair of numbers, and apply it to every image's corresponding prediction.




In [None]:
# Print the predicted class of every image; argmax_all was already filled by the inference loop above.
print(argmax_all)

In [None]:


# Now we make a list of our ground-truth labels:
# gt_all = []
# for idx, (data, labels) in enumerate(dl_test):
#      gt_all.extend(labels.detach().cpu().numpy())

# print(gt_all)




In [None]:
# Put ground truth and predicted class side by side in one table, one row per image.
# NOTE(review): gt_all and argmax_all must already exist and have the same length when this cell runs.
import pandas as pd

table = {'gt': gt_all, 'pred': argmax_all}
df = pd.DataFrame(table)
print(df)

In [None]:
# Show every image of the dataset as its own figure, titled with its ground truth and prediction.

import matplotlib.pyplot as plt
import numpy as np


def _show_prediction(i):
    """Plot image number `i` with its ground-truth and predicted class in the title."""
    # permute(1, 2, 0) moves the first dimension to the end
    # (presumably channels-first (C, H, W) -> (H, W, C) as imshow expects - confirm).
    image = img_list[i].cpu().permute(1, 2, 0)
    plt.imshow(image)
    plt.title(f'Ground truth: {gt_all[i]}, Prediction: {argmax_all[i]}')
    plt.show()


# One figure per row of the results table.
for i in range(len(df)):
    _show_prediction(i)


     

In [None]:
# Same thing now but in a smaller grid

def display_batch(img_list, gt_all, argmax_all, start_idx=0, rows=4, cols=3):
    """Show one page of images in a grid, each titled with ground truth and prediction.

    Args:
        img_list: Sequence of image tensors; each is moved to the CPU and
            permuted with (1, 2, 0) before plotting.
        gt_all: Ground-truth label per image, indexed like ``img_list``.
        argmax_all: Predicted class per image, indexed like ``img_list``.
        start_idx: Index of the first image shown on this page.
        rows: Number of grid rows (default 4).
        cols: Number of grid columns (default 3).

    Up to ``rows * cols`` images are drawn; grid cells beyond the end of
    ``img_list`` are left blank.
    """
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so ravel() always works.
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    axes = axes.ravel()

    for i, ax in enumerate(axes):
        # Hide frame and ticks on every cell, including the ones that stay
        # empty on a partially filled last page.
        ax.axis('off')

        idx = start_idx + i
        if idx >= len(img_list):
            continue

        # Display image and labels
        ax.imshow(img_list[idx].cpu().permute(1, 2, 0))
        ax.set_title(f'Image {idx}\nGT: {gt_all[idx]}\nPred: {argmax_all[idx]}')

    plt.tight_layout()
    plt.show()
    # Release the figure so paging through a large dataset does not pile up open figures.
    plt.close(fig)

# Page through the whole results table, drawing one grid figure per 12 images.
IMAGES_PER_PAGE = 12
num_images = len(df)
for batch_start in range(0, num_images, IMAGES_PER_PAGE):
    display_batch(img_list, gt_all, argmax_all, start_idx=batch_start)

In [None]:
# Next steps: visualize and evaluate
# Visualize - look @ images and compare to prediction we got.

# Overlay two histograms: how often each class occurs in the ground truth (blue)
# and in the predictions (red). alpha=0.5 keeps the overlap visible.
# NOTE: this compares class counts, not model confidence. A confidence histogram
# separated by empty / forage fish images would need the raw model outputs instead.
# seaborn has to be installed first, e.g.: conda install seaborn
import matplotlib.pyplot as plt
import seaborn as sns

# discrete=True gives every integer class its own bar centred on the class value.
sns.histplot(df['gt'], color='blue', alpha=0.5, label='Ground Truth', discrete=True)
sns.histplot(df['pred'], color='red', alpha=0.5, label='Prediction', discrete=True)
plt.xlabel('Class')
plt.legend()
plt.show()


# These should have different colours with an alpha value to see the overlap.
# We will use the matplotlib library to plot this.




#Now we want to visualize the predictions and ground truth in matplotlib in a plot:
# plt.plot(df['gt'], label='Ground Truth')
# plt.plot(df['pred'], label='Prediction')
# #give the x-axis a label
# plt.xlabel('Image')
# #give the y-axis a label
# plt.ylabel('Class')
# plt.legend()
# plt.show()

# import sklearn
# print(sklearn.__version__)


# # Now print the accuracy of the model:
# from sklearn.metrics import accuracy_score
# accuracy = accuracy_score(df['gt'], df['pred'])
# print(f'Accuracy: {accuracy}')
