# Imports

In [None]:
from matplotlib import image as mpimg
import tensorflow as tf
import torch
import os
import re
import requests
from PIL import Image
from torchvision.transforms import ToTensor
from tensorflow import keras
import matplotlib.pyplot as plt     # to plot charts
import numpy as np
import pandas as pd                 # for data manipulation
pd.set_option('display.max_colwidth', None)
import cv2                          # for image processing
from io import BytesIO
from tabulate import tabulate       # to print pretty tables
import seaborn as sns
import shutil

# sklearn imports for metrics and dataset splitting
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# keras imports for image preprocessing
from keras.preprocessing.image import ImageDataGenerator

# huggingface imports for model building 
import torch.nn as nn
from transformers import ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
  default_data_collator, EarlyStoppingCallback, ViTConfig, AutoImageProcessor, ViTImageProcessor, AutoModel 
from transformers.modeling_outputs import SequenceClassifierOutput

# torch / torchvision imports for image transforms, datasets and data loading
from torchvision.transforms import ToTensor, Resize
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D, Dataset
import datasets

from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

from safetensors.numpy import  load_file

In [None]:
# Notebook-wide constants.
# Base URL of the pill image collection (same URL is used again as `website_url` below).
DATASET_URL = 'https://data.lhncbc.nlm.nih.gov/public/Pills/'
# Name of the local dataset directory (not referenced elsewhere in the visible notebook).
directory = "dataset"
# Directories the training / test images are read from by the generators and previews.
train_dir = "training20_set"
test_dir = "testing20_set"
# Target size passed to the image generators.
img_height = 224
img_width = 224
# Batch size used by the image generators and the test DataLoader.
batch_size = 64

# Import (Top 20)

In [None]:
# Load the three metadata tables used to build the dataset.
csv_file_top20 = "./top20.csv"
csv_file = "./table.csv"
# NOTE: despite the variable name this is an Excel workbook, hence read_excel below.
csv_file2 = "./directory_consumer_grade_images.xlsx"
top20_df = pd.read_csv(csv_file_top20)
table_df = pd.read_csv(csv_file)
directory_df = pd.read_excel(csv_file2)

# Medication names (column 'Name' of top20.csv); these become the class labels.
top20_list = top20_df['Name'].tolist()

In [None]:
# Remove .wmv (video) entries.
# regex=False makes '.wmv' a literal substring; with the default regex=True the
# '.' would match any character, so e.g. 'xwmv' inside a path would also be dropped.
directory_df = directory_df[~directory_df['Image'].str.contains('.wmv', case=False, na=False, regex=False)]

In [None]:
# Create a function to get the base label
def get_base_label(label, names=None):
    """Return the first medication name contained in *label*.

    The comparison is a case-insensitive plain substring test, and candidates
    are tried in list order, so the first matching entry wins.

    Args:
        label: Full product name. Non-string values (e.g. NaN) are returned
            unchanged instead of raising.
        names: Candidate base names. Defaults to the notebook-level
            ``top20_list``.

    Returns:
        The matching entry of *names*, or *label* itself when nothing matches.
    """
    if names is None:
        names = top20_list
    if not isinstance(label, str):
        return label

    # Lower-case the label once instead of once per candidate.
    lowered = label.lower()
    for item in names:
        if item.lower() in lowered:
            return item
    return label

In [None]:
# find the top 20 medications in the two datasets
# Names are matched as literal, case-insensitive substrings (regex=False). This is
# the same test get_base_label applies, so every selected row can be mapped back to
# a top-20 name and encoder.transform below does not meet unseen labels. With the
# default regex=True, a name containing characters such as '(' or '+' would be
# interpreted as a pattern and could select different rows.

# find matches in table_df
table_names = table_df['name'].fillna('')
matches_in_table_df = pd.concat(
    [table_df[table_names.str.contains(item, case=False, na=False, regex=False)] for item in top20_list]
)

# find matches in directory_df
directory_names = directory_df['Name'].fillna('')
matches_in_directory_df = pd.concat(
    [directory_df[directory_names.str.contains(item, case=False, na=False, regex=False)] for item in top20_list]
)

# generate the test set
test_df = matches_in_directory_df[matches_in_directory_df['Layout'] == 'C3PI_Test']

# keep only necessary images
matches_in_directory_df = matches_in_directory_df[matches_in_directory_df['Layout'].isin(['MC_API_NLMIMAGE_V1.3', 'MC_CHALLENGE_V1.0'])]

# remove unnecessary columns and rename columns
matches_in_table_df = matches_in_table_df[['name', 'nlmImageFileName']]
matches_in_table_df = matches_in_table_df.rename(columns={'name': 'labels', 'nlmImageFileName': 'image_paths'})
matches_in_directory_df = matches_in_directory_df[['Image', 'Name']]
matches_in_directory_df = matches_in_directory_df.rename(columns={'Image': 'image_paths', 'Name': 'labels'})
test_df = test_df[['Image', 'Name']]
test_df = test_df.rename(columns={'Image': 'image_paths', 'Name': 'labels'})

# add a base label column for the top 20 medications
matches_in_table_df['base_label'] = matches_in_table_df['labels'].apply(get_base_label)
matches_in_directory_df['base_label'] = matches_in_directory_df['labels'].apply(get_base_label)
test_df['base_label'] = test_df['labels'].apply(get_base_label)

# instantiate the label encoder
encoder = LabelEncoder()
encoder.fit(top20_list)

# Transform the data in the dataframes (base_label becomes an integer class id)
matches_in_table_df['base_label'] = encoder.transform(matches_in_table_df['base_label'])
matches_in_directory_df['base_label'] = encoder.transform(matches_in_directory_df['base_label'])
test_df['base_label'] = encoder.transform(test_df['base_label'])

# Training pool: rows from both sources. The original row indices are kept, so
# index labels may repeat.
top20_instances_df = pd.concat([matches_in_table_df, matches_in_directory_df])

In [None]:
# temp_df = test_df.copy()
# test_df = top20_instances_df.copy()
# top20_instances_df = temp_df

In [None]:
# Show the first rows of the training pool (image_paths, labels, base_label).
top20_instances_df.head()

In [None]:
# Report the number of samples (rows). DataFrame.size would report
# rows * columns, which overstates the set size by the column count.
print('training set size: ',len(top20_instances_df))
print('test set size: ',len(test_df))
print('number of unique labels: ',len(top20_instances_df['base_label'].unique()))

In [None]:
# Show the first rows of the training pool again.
top20_instances_df.head()

In [None]:
# Show the first rows of the test set.
test_df.head()

In [None]:
# Class-balance check for the training and test sets.
# NOTE: this cell overwrites 'base_label' with the decoded medication names in
# both frames; the following cell encodes them back to integer ids.

# Decode the labels in top20_instances_df
top20_instances_df['base_label'] = encoder.inverse_transform(top20_instances_df['base_label'])

# Check if the data is imbalanced in the training set
train_label_counts = top20_instances_df['base_label'].value_counts()
print(train_label_counts)

# Plot the label counts
plt.figure(figsize=(10,6))
plt.bar(train_label_counts.index, train_label_counts.values, alpha=0.5, color='g')
plt.title('Distribution of Base Labels (Training Set)')
plt.xlabel('Base Label')
plt.ylabel('Number of Labels')
plt.xticks(rotation=90)
plt.grid(True)
plt.show()

# Decode the labels in test_df
test_df['base_label'] = encoder.inverse_transform(test_df['base_label'])

# Check if the data is imbalanced in the test set
test_label_counts = test_df['base_label'].value_counts()
print(test_label_counts)

# Plot the label counts for the test set
plt.figure(figsize=(10,6))
plt.bar(test_label_counts.index, test_label_counts.values, alpha=0.5, color='b')
plt.title('Distribution of Base Labels (Test Set)')
plt.xlabel('Base Label')
plt.ylabel('Number of Labels')
plt.xticks(rotation=90)
plt.grid(True)
plt.show()

In [None]:
# Re-encode the labels: turn the medication names in 'base_label' back into the
# integer ids produced by the fitted LabelEncoder.
top20_instances_df['base_label'] = encoder.transform(top20_instances_df['base_label'])
test_df['base_label'] = encoder.transform(test_df['base_label'])

# Downloading the Training Data

In [None]:
# Remote base URL, optional local mirror of the dataset, and the directory the
# training images are collected in.
website_url = 'https://data.lhncbc.nlm.nih.gov/public/Pills/'
dataset_dir = './dataset'
training_dir = './training20_set'

# Make sure the training directory exists. exist_ok=True replaces the separate
# existence check, so the call is also safe if the directory appears in between.
os.makedirs(training_dir, exist_ok=True)

# Function to download an image from a URL and save it to a directory
def download_image(url, save_path, timeout=30):
    """Stream the resource at *url* into the file *save_path*.

    Args:
        url: Address of the image to fetch.
        save_path: Destination file path; overwritten if it exists.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would block indefinitely.

    Returns:
        True when the server answered with status 200 and the body was written,
        False when the request failed or returned another status code.

    Raises:
        Exception: Errors raised while writing the file are re-raised after the
            partially written file has been removed.
    """
    try:
        response = requests.get(url, stream=True, timeout=timeout)
    except requests.RequestException as exc:
        # A single unreachable URL should not abort a loop over many files.
        print(f"Failed to download image from {url}: {exc}")
        return False

    # The context manager releases the streamed connection in every code path.
    with response:
        if response.status_code != 200:
            print(f"Failed to download image from {url}")
            return False

        # Let the raw stream undo gzip/deflate content encoding while copying.
        response.raw.decode_content = True
        try:
            with open(save_path, 'wb') as f:
                shutil.copyfileobj(response.raw, f)
        except Exception:
            # Do not leave a truncated file behind: callers skip the download
            # when save_path already exists.
            if os.path.exists(save_path):
                os.remove(save_path)
            raise
    return True

# Collect every training image in training_dir: copy it from the local dataset
# mirror when available, otherwise download it unless it is already present.
for file_name in top20_instances_df['image_paths']:
    save_path = os.path.join(training_dir, os.path.basename(file_name))
    local_source = os.path.join(dataset_dir, file_name)

    if os.path.exists(local_source):
        shutil.copy(local_source, save_path)
        continue

    if os.path.exists(save_path):  # already fetched by an earlier run
        print(f"{file_name} already exists in {training_dir}, skipping download.")
        continue

    url = website_url + file_name
    if download_image(url, save_path):
        print(f"Downloaded {file_name} from {url}")
    else:
        print(f"Failed to find {file_name} in dataset_dir and download from {url}")


# Downloading the Test Data

In [None]:
testing_dir = './testing20_set'

# Make sure the testing directory exists (no separate existence check needed)
os.makedirs(testing_dir, exist_ok=True)

for index, row in test_df.iterrows():
    file_name = row['image_paths']

    # Skip video files. The comparison is case-insensitive so that both
    # '.wmv' and '.WMV' are caught.
    if file_name.lower().endswith('.wmv'):
        print(f"Skipping {file_name} as it has the .wmv extension")
        continue

    save_path = os.path.join(testing_dir, os.path.basename(file_name))
    if os.path.exists(os.path.join(dataset_dir, file_name)):
        # Prefer the local dataset mirror over a download
        shutil.copy(os.path.join(dataset_dir, file_name), save_path)
    elif not os.path.exists(save_path):  # Check if file already exists in the target directory
        url = website_url + file_name
        if download_image(url, save_path):
            print(f"Downloaded {file_name} from {url}")
        else:
            print(f"Failed to find {file_name} in dataset_dir and download from {url}")
    else:
        print(f"{file_name} already exists in {testing_dir}, skipping download. It was downloaded previously.")


In [None]:
# Report how many files each image directory currently holds
training_file_count = len(os.listdir('./training20_set'))
print('Number of files in the training set: ', training_file_count)
testing_file_count = len(os.listdir('./testing20_set'))
print('Number of files in the test set: ', testing_file_count)

# Helper Functions

In [None]:
# Function to convert an image file to a tensor
def image_to_tensor(image_file, size=(224, 224)):
    """Load an image file, resize it and convert it to a tensor.

    Args:
        image_file: Path or file object accepted by ``Image.open``.
        size: Target size handed to ``Resize``; defaults to (224, 224).

    Returns:
        The result of ``ToTensor()`` applied to the resized image.
    """
    # The context manager closes the underlying file once the resized copy
    # has been produced.
    with Image.open(image_file) as image:
        resized = Resize(size)(image)
    return ToTensor()(resized)

#calculate the weights
def get_weight(class_num, label_count):
    """Return inverse-log-frequency weights scaled to sum to *class_num*.

    Args:
        class_num: Value the returned weights sum to.
        label_count: Array-like of counts; every count must be greater than 1
            because the weight is 1 / log(count).

    Returns:
        Weights in the same container type numpy returns for *label_count*.
    """
    weights = 1 / np.log(label_count)
    weights = class_num * weights/np.sum(weights)
    return weights


def add_class_weights(input_data):
    """Add a 'weights' column holding the class weight of each row's label.

    Rare labels receive larger weights (see ``get_weight``).

    Args:
        input_data: DataFrame with a 'labels' column.

    Returns:
        The same DataFrame object, modified in place, with the new 'weights'
        column.
    """
    result_data = input_data

    # number of distinct labels
    label_num = len(result_data['labels'].unique())

    # occurrences of each label, indexed by label
    label_counts = result_data['labels'].value_counts()

    # Counts are doubled before taking the logarithm; presumably so that a
    # label seen once yields log(2) rather than log(1) == 0.
    base = 2
    weight_by_label = get_weight(label_num, base * label_counts)

    # Look the weight up by label value. A per-row .loc assignment would fail
    # on frames whose index labels repeat (e.g. after pd.concat).
    result_data['weights'] = result_data['labels'].map(weight_by_label)

    return result_data

# Preview the Data

In [None]:
# Reduce every image path to its bare file name (the images were stored flat
# in the training / testing directories)
for frame in (test_df, top20_instances_df):
    frame['image_paths'] = frame['image_paths'].map(os.path.basename)

In [None]:
# List file name and decoded label of the first 5 training samples
first_rows = top20_instances_df.head(5)
for file_name, class_id in zip(first_rows['image_paths'], first_rows['base_label']):
    print("Image:", file_name)
    print("Label:", encoder.inverse_transform([class_id])[0])
    print()

# Show the first 9 training images in a 3x3 grid, titled with their labels
plt.figure(figsize=(10, 10))
grid_rows = top20_instances_df.head(9)
for cell, (file_name, class_id) in enumerate(zip(grid_rows['image_paths'], grid_rows['base_label']), start=1):
    plt.subplot(3, 3, cell)

    # images live directly inside training_dir
    with Image.open(os.path.join(training_dir, file_name)) as img:
        plt.imshow(img)

    plt.title(encoder.inverse_transform([class_id])[0])
    plt.axis("off")

In [None]:
# List file name and decoded label of the first 5 test samples
first_rows = test_df.head(5)
for file_name, class_id in zip(first_rows['image_paths'], first_rows['base_label']):
    print("Image:", file_name)
    print("Label:", encoder.inverse_transform([class_id])[0])
    print()

# Show the first 9 test images in a 3x3 grid, titled with their labels
plt.figure(figsize=(10, 10))
grid_rows = test_df.head(9)
for cell, (file_name, class_id) in enumerate(zip(grid_rows['image_paths'], grid_rows['base_label']), start=1):
    plt.subplot(3, 3, cell)

    # only the file name is used; the image is looked up inside test_dir
    with Image.open(os.path.join(test_dir, os.path.basename(file_name))) as img:
        plt.imshow(img)

    plt.title(encoder.inverse_transform([class_id])[0])
    plt.axis("off")

# Data Augmentation

In [None]:
# Convert the path and label columns to strings.
# After this cell 'base_label' holds the encoded class id as text (e.g. '3').
top20_instances_df["image_paths"] = top20_instances_df["image_paths"].astype(str)
top20_instances_df["base_label"] = top20_instances_df["base_label"].astype(str)

# The whole training pool is used for training.
train_df = top20_instances_df.copy()
test_df["image_paths"] = test_df["image_paths"].astype(str)
test_df["base_label"] = test_df["base_label"].astype(str)

# train_df, test_df = train_test_split(test_df, test_size=0.2, random_state=42)
# Hold out 20% of the test rows as the evaluation set used during training;
# the remaining 80% stay in test_df. Fixed seed for a reproducible split.
test_df, eval_df = train_test_split(test_df, test_size=0.2, random_state=42)

In [None]:
# Keras image generators for the three splits plus a visual check of the
# augmentation settings.
# NOTE(review): no explicit `classes=` list is passed to flow_from_dataframe, so
# each generator presumably derives its own label-to-index mapping from its
# dataframe -- verify the mappings agree if these generators are used together.

# Create the image data generator for the training set
# (random rotation, shear, zoom and shifts; pixel values scaled by 1/255)
imageTrain_data = ImageDataGenerator(
    rescale = 1./255.,
    rotation_range = 60,
    shear_range = 0.3,
    zoom_range = 0.5,
    width_shift_range = 0.3,
    height_shift_range = 0.3,
    fill_mode="nearest",
)

train_generator = imageTrain_data.flow_from_dataframe(
    dataframe=train_df,
    directory=train_dir,
    target_size = (img_height, img_width),
    batch_size = batch_size,
    x_col = "image_paths",
    y_col = "base_label",
    class_mode="categorical",
)


# Create the image data generator for the evaluation set (rescaling only)
imageEval_data = ImageDataGenerator(rescale = 1./255.)

# The evaluation rows were split off the test set, so the images are in test_dir
eval_generator = imageEval_data.flow_from_dataframe(
    dataframe=eval_df,
    directory=test_dir,
    target_size = (img_height, img_width),
    batch_size = batch_size,
    x_col = "image_paths",
    y_col = "base_label",
    class_mode="categorical",
)


# Create the image data generator for the test set (rescaling only)
imageTest_data = ImageDataGenerator(rescale = 1./255.)

test_generator = imageTest_data.flow_from_dataframe(
    dataframe=test_df,
    directory=test_dir,
    target_size = (img_height, img_width),
    batch_size = batch_size,
    x_col = "image_paths",
    y_col = "base_label",
    class_mode="categorical",
)

#Display example of image augmentation
# One random test image is fed through the *training* augmentation pipeline
sample_dataframe = test_df.sample(n=1).reset_index(drop=True)
sample_generator = imageTrain_data.flow_from_dataframe(
    dataframe=sample_dataframe,
    directory=test_dir,
    target_size = (img_height, img_width),
    batch_size = batch_size,
    x_col = "image_paths",
    y_col = "base_label",
    class_mode="categorical",
)

# Draw 15 augmented variants of that image in a 5x3 grid
plt.figure(figsize=(12, 12))
for i in range (0, 15):
  ax = plt.subplot(5, 3, i + 1)
  # take a single batch from the generator and show its first image
  for X_column, Y_column in sample_generator:
    plt.imshow(X_column[0])
    break
plt.tight_layout()
plt.show()

# Model Training

In [None]:
# All encoded labels across the three splits (only used by the commented-out
# count below; the name is reused for the prediction targets in the testing section).
all_labels = pd.concat([train_df['base_label'], eval_df['base_label'], test_df['base_label']])
# num_labels = all_labels.nunique() - 1
# Number of output classes of the classifier head, fixed at 20.
num_labels = 20

In [None]:
# Check if CUDA is available and set the device accordingly.
# `device` is read by the model class, the preprocessing function, the data
# collator and the custom trainer below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# # Compute class weights
# class_counts = np.bincount(train_labels)
# class_weights = 1. / class_counts
# class_weights = class_weights / np.sum(class_weights) * len(class_counts)
# class_weights = torch.FloatTensor(class_weights).to(device)
    
class ViTForImageClassification(nn.Module):
    """ViT backbone followed by a linear classification head.

    NOTE: this class has the same name as the ``ViTForImageClassification``
    imported from ``transformers`` at the top of the notebook and replaces it
    from this point on.

    Args:
        config: Model configuration; ``hidden_size`` and ``num_labels`` size
            the classification head.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # self.vit = ViTModel(config, add_pooling_layer=False)
        # self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.vit = ViTModel(config, add_pooling_layer=False).to(device)                # Move model to Nvidia card
        self.classifier = nn.Linear(config.hidden_size, config.num_labels).to(device)  # Move model to Nvidia card

    def forward(self, pixel_values, labels=None):
        """Run the backbone and classify the first token of the last layer.

        Args:
            pixel_values: Batch of preprocessed images.
            labels: Optional class ids. When given, the cross-entropy loss is
                returned as well. Defaults to None so that the model can be
                called for inference without labels.

        Returns:
            SequenceClassifierOutput with ``loss`` (None without labels),
            ``logits`` and the backbone's ``hidden_states`` / ``attentions``.
        """
        outputs = self.vit(pixel_values=pixel_values)
        # index 0 along the sequence dimension of the last hidden state
        logits = self.classifier(outputs.last_hidden_state[:, 0])
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            # loss_fct = nn.CrossEntropyLoss(weight=class_weights)
            loss = loss_fct(logits, labels)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    

# compute accuracy
def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    The accuracy is computed directly with numpy, so no metric script has to
    be loaded on every evaluation.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax over the last axis of ``logits``. A single integer label
            is accepted as well.

    Returns:
        ``{"accuracy": value}`` with the fraction of correct predictions.

    Raises:
        ValueError: If predictions and labels differ in shape.
    """
    logits, labels = eval_pred
    predictions = np.atleast_1d(np.argmax(logits, axis=-1))
    labels = np.atleast_1d(labels)
    if predictions.shape != labels.shape:
        raise ValueError(
            f"predictions {predictions.shape} and labels {labels.shape} differ in shape"
        )
    accuracy = {"accuracy": float(np.mean(predictions == labels))}
    print(f"Accuracy: {accuracy}")
    return accuracy
    
        
# Image processor that turns images into the model's `pixel_values`
# (resize and per-channel normalisation with mean/std 0.5).
# NOTE(review): do_rescale=False while load_and_preprocess_images passes uint8
# arrays -- confirm the pixel value range the normalisation is meant to see.
feature_extractor = ViTImageProcessor(
    image_size=224,
    do_resize=True,
    do_normalize=True,
    do_rescale=False,
    image_mean=[0.5, 0.5, 0.5],
    image_std=[0.5, 0.5, 0.5],
)


# Define a function to load and preprocess the images
def load_and_preprocess_images(example, directory):
    """Load one image and convert it into model inputs.

    Args:
        example: Mapping with 'image_paths' (file name) and 'labels'
            (class id, convertible with ``int``).
        directory: Directory containing the image; a trailing slash is
            optional.

    Returns:
        Dict with 'pixel_values' (float32 tensor on ``device``) and 'labels'
        (int).
    """
    # Load the image from the file; the context manager closes the file handle
    with Image.open(os.path.join(directory, example['image_paths'])) as img:
        # Force three channels so greyscale / RGBA files produce the same
        # (H, W, 3) layout the axis move below relies on
        image = np.array(img.convert('RGB'), dtype=np.uint8)
    # channels-last -> channels-first
    image = np.moveaxis(image, source=-1, destination=0)
    # Preprocess the image
    inputs = feature_extractor(images=[image])
    pixel_values = torch.tensor(inputs['pixel_values'][0], dtype=torch.float32).to(device)  # convert to tensor and move to device
    label = int(example['labels'])
    return {'pixel_values': pixel_values, 'labels': label}


# define a custom data collator
def data_collator(features):
    """Merge a list of samples into one batch located on ``device``.

    Each sample provides 'pixel_values' and 'labels'. The result is a dict
    with the stacked float32 image tensor and a tensor of labels.
    """
    image_batch = torch.stack([
        torch.tensor(sample['pixel_values'], dtype=torch.float32).to(device)  # Move to device
        for sample in features
    ])
    label_batch = torch.tensor([sample['labels'] for sample in features]).to(device)  # Move to device
    return {'pixel_values': image_batch, 'labels': label_batch}


# Define the features of the dataset
# (not passed to any of the datasets created below)
features = Features({
    'labels': ClassLabel(num_classes=num_labels),
    'img': Array3D(dtype="int64", shape=(3, 32, 32)),
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
})

# Generate lists of image paths and labels for training dataset
train_image_paths = train_df["image_paths"].tolist()
train_labels = train_df["base_label"].tolist()

# Create a dictionary with the image paths and labels
train_dict = {'image_paths': train_image_paths, 'labels': train_labels}

# Create the dataset
train_dataset = Dataset.from_dict(train_dict)

# Apply the preprocessing function to the dataset. Without this step the
# dataset has no 'pixel_values' column and its labels are still strings, so
# data_collator cannot build a batch. Training images are in training20_set/.
train_dataset = train_dataset.map(lambda example: load_and_preprocess_images(example, 'training20_set/'))
train_dataset = train_dataset.remove_columns(['image_paths'])


# Repeat the same process for the evaluation dataset; its rows were split off
# the test set, so the images are in testing20_set/.
eval_image_paths = eval_df["image_paths"].tolist()
eval_labels = eval_df["base_label"].tolist()
eval_dict = {'image_paths': eval_image_paths, 'labels': eval_labels}
eval_dataset = Dataset.from_dict(eval_dict)
eval_dataset = eval_dataset.map(lambda example: load_and_preprocess_images(example, 'testing20_set/'))
eval_dataset = eval_dataset.remove_columns(['image_paths'])


# The test dataset is not used by the trainer; it is rebuilt and preprocessed
# from test_dict in the Model Testing section, so the costly map step is left
# disabled here.
test_image_paths = test_df["image_paths"].tolist()
test_labels = test_df["base_label"].tolist()
test_dict = {'image_paths': test_image_paths, 'labels': test_labels}
test_dataset = Dataset.from_dict(test_dict)
# test_dataset = test_dataset.map(lambda example: load_and_preprocess_images(example, 'testing20_set/'))
test_dataset = test_dataset.remove_columns(['image_paths'])


# Load the pre-trained model
pretrained_model = ViTModel.from_pretrained('google/vit-base-patch16-224')

# Define your custom model: reuse the pre-trained configuration and set the
# number of output classes for the classification head
config = pretrained_model.config
config.num_labels = num_labels
model = ViTForImageClassification(config)

# Copy the pre-trained weights to your custom model
# (the backbone created in the constructor is replaced by the pre-trained module)
model.vit = pretrained_model

# Stop training when the monitored metric has not improved for 10 evaluations
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=10,  # Number of evaluations with no improvement after which training will be stopped.
    early_stopping_threshold=0.0  # Threshold for measuring the new optimum, to only focus on significant changes.
)

# create the training arguments
# Evaluation and checkpointing both happen once per epoch; the checkpoint with
# the highest 'accuracy' (the key returned by compute_metrics) is restored at the end.
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=20,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=16,   # batch size for evaluation
    warmup_steps=75,                # number of warmup steps for learning rate scheduler
    weight_decay=0.018,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=5,
    logging_first_step=True,
    logging_strategy='steps',
    evaluation_strategy='epoch',
    eval_steps=10,  # NOTE(review): presumably only relevant for a step-based evaluation strategy -- confirm
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    greater_is_better=True,
    learning_rate=5e-5,
    gradient_accumulation_steps=1,      # 1 = optimizer step after every batch (no accumulation)
    max_grad_norm=1.0,                  # gradient clipping threshold; limits exploding gradients
    # fp16=True                     # mixed precision training; enable if using nVidia graphics cards
)

class CustomTrainer(Trainer):
    """Trainer with a plain shuffled DataLoader and an explicit loss."""

    def get_train_dataloader(self):
        """Return a shuffling DataLoader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.args.train_batch_size, shuffle=True, collate_fn=self.data_collator)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the cross-entropy loss between the logits and the labels.

        Args:
            model: Model to run.
            inputs: Batch dict; must contain 'labels'. Its tensors are moved
                to ``device`` in place.
            return_outputs: When True, also return the model outputs.
            num_items_in_batch: Accepted for compatibility with Trainer
                versions that pass this argument; not used.

        Returns:
            The loss, or ``(loss, outputs)`` when *return_outputs* is True.
        """
        # Move inputs to device
        for key, value in inputs.items():
            inputs[key] = value.to(device)

        outputs = model(**inputs)
        logits = outputs.logits

        labels = inputs["labels"]  # Get labels from inputs

        loss = torch.nn.functional.cross_entropy(logits, labels)
        return (loss, outputs) if return_outputs else loss
        
# Wire model, arguments, datasets, collator, metric function and the early
# stopping callback into the trainer
mainTrainer = CustomTrainer (
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping_callback]  # Add the early stopping callback
)

In [None]:
# Fine-tune the model, then write the resulting model to ./saved_model
mainTrainer.train()
mainTrainer.save_model('./saved_model')

# Model Testing

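Download the trained model from HuggingFace into the directory ./saved_model. Note that the next cell loads the weights from `./saved_model1 - best/model.safetensors`; adjust one of the two paths so that they match.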

In [None]:
from safetensors import safe_open

# Read every tensor stored in the checkpoint into a name -> tensor dict (on CPU)
tensors = {}
with safe_open("./saved_model1 - best/model.safetensors", framework="pt", device='cpu') as checkpoint:
    tensors.update((key, checkpoint.get_tensor(key)) for key in checkpoint.keys())

In [None]:
# Define a function to load and preprocess the images for testing
def load_and_preprocess_images(example, directory):
    """Load one image and convert it into model inputs.

    Args:
        example: Mapping with 'image_paths' (file name) and 'labels'
            (class id, convertible with ``int``).
        directory: Directory containing the image; a trailing slash is
            optional.

    Returns:
        Dict with 'pixel_values' (float32 tensor on ``device``) and 'labels'
        (int).
    """
    # The context manager closes the file handle once the pixels are copied
    with Image.open(os.path.join(directory, example['image_paths'])) as img:
        # Force three channels so greyscale / RGBA files produce the same
        # (H, W, 3) layout the axis move below relies on
        image = np.array(img.convert('RGB'), dtype=np.uint8)
    # channels-last -> channels-first
    image = np.moveaxis(image, source=-1, destination=0)
    inputs = feature_extractor(images=[image])
    pixel_values = torch.tensor(inputs['pixel_values'][0], dtype=torch.float32).to(device)
    label = int(example['labels'])
    return {'pixel_values': pixel_values, 'labels': label}

# Rebuild the test dataset from test_dict and preprocess every image.
# The test images are read from testing20_set/; the path column is dropped
# afterwards so only 'pixel_values' and 'labels' remain.
test_dataset = Dataset.from_dict(test_dict)
test_dataset = test_dataset.map(lambda example: load_and_preprocess_images(example, 'testing20_set/'))
test_dataset = test_dataset.remove_columns(['image_paths'])

In [None]:
# Load the pre-trained model
pretrained_model = ViTModel.from_pretrained('google/vit-base-patch16-224')

# Define your custom model
config = pretrained_model.config
config.num_labels = num_labels
model = ViTForImageClassification(config)

# Use the pre-trained module as backbone so the parameter names match the
# saved checkpoint
model.vit = pretrained_model

# Restore the fine-tuned weights
model.load_state_dict(tensors)

# The backbone assigned above was created on the CPU, while data_collator
# places every batch on `device`; move the whole model there to avoid a
# device mismatch on GPU machines.
model.to(device)

# Set the model to evaluation mode
model.eval()

# Batched, unshuffled iteration over the test dataset
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator)

# Accumulators: predicted class ids, true class ids, one loss value per batch
all_predictions = []
all_labels = []
all_losses = []

# Inference only, so gradient tracking is disabled for the whole pass
with torch.no_grad():
    for batch in test_dataloader:
        labels = batch['labels']
        outputs = model(pixel_values=batch['pixel_values'], labels=labels)

        # predicted class = index of the largest logit
        batch_predictions = torch.argmax(outputs.logits, dim=1).cpu().numpy()

        all_predictions.extend(batch_predictions)
        all_labels.extend(labels.cpu().numpy())
        all_losses.append(outputs.loss.item())

# Convert lists to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_losses = np.array(all_losses)

In [None]:
# Calculate accuracy and other metrics
# (precision / recall / F1 are averaged over classes weighted by support)
accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, fscore, _ = precision_recall_fscore_support(all_labels, all_predictions, average='weighted')
classification_report_str = classification_report(all_labels, all_predictions)

# Print or use the metrics as needed
print(f"Accuracy: {accuracy}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {fscore}")
# Re-render the text report as a grid table: the first line supplies the column
# headers, lines [2:-5] are taken as the per-class rows (the slice drops the
# trailing lines of the report).
print("Classification Report:\n", tabulate([[''] + classification_report_str.split('\n')[0].split()] + [line.split() for line in classification_report_str.split('\n')[2:-5]], headers='firstrow', tablefmt='grid'))

# Left panel: loss of each test batch
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.plot(all_losses, label='Test Loss')
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.legend()

# Right panel: predicted class id against true class id, one dot per sample
plt.subplot(1, 2, 2)
plt.plot(all_labels, all_predictions, 'bo', markersize=3)
plt.xlabel('True Labels')
plt.ylabel('Predicted Labels')
plt.title('True vs Predicted Labels')
plt.show()