# Imports

In [None]:
from matplotlib import image as mpimg
import tensorflow as tf
import torch
import os
import requests
from PIL import Image
from torchvision.transforms import ToTensor
from tensorflow import keras
import matplotlib.pyplot as plt     # to plot charts
import numpy as np
import pandas as pd                 # for data manipulation
import cv2                          # for image processing
from io import BytesIO
from tabulate import tabulate       # to print pretty tables
import seaborn as sns
import shutil

# sklearn imports for metrics and dataset splitting
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# keras imports for image preprocessing
from keras.preprocessing.image import ImageDataGenerator

# huggingface imports for model building 
import torch.nn as nn
from transformers import ViTModel, ViTForImageClassification, TrainingArguments, Trainer, \
  default_data_collator, EarlyStoppingCallback, ViTConfig, AutoImageProcessor, ViTImageProcessor 
from transformers.modeling_outputs import SequenceClassifierOutput

# torch / torchvision imports for image loading, transforms, and Dataset/DataLoader utilities
from torchvision.transforms import ToTensor, Resize
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image

from datasets import load_dataset, load_metric, Features, ClassLabel, Array3D, Dataset
import datasets

from imblearn.over_sampling import RandomOverSampler
from imblearn.under_sampling import RandomUnderSampler

# Load the Top-20 Medication Data

In [None]:
# Source tables used to build the top-20 dataset
csv_file_top20 = "./top20.csv"                        # list of the 20 target medications ('Name' column)
csv_file = "./table.csv"                              # NLM image table ('name', 'nlmImageFileName')
csv_file2 = "./directory_consumer_grade_images.xlsx"  # an Excel workbook, despite the variable name

# Read the three tables, in the same order as they are declared above
top20_df = pd.read_csv(csv_file_top20)
table_df = pd.read_csv(csv_file)
directory_df = pd.read_excel(csv_file2)

# Plain Python list of the target medication names, in file order
top20_list = top20_df['Name'].to_list()

In [None]:
# Create a function to get the base label
def get_base_label(label):
    for item in top20_list:
        if item.lower() in label.lower():
            return item
    return label

In [None]:
# find the top 20 medications in the two datasets
import re  # only needed to escape medication names for the regex below

# One alternation pattern covering every top-20 name. re.escape keeps regex
# metacharacters inside a name (parentheses, '+', '.') from being interpreted.
_top20_pattern = '|'.join(re.escape(item) for item in top20_list)
# Combination products ("X and Y", "X/Y") are excluded. \b makes 'and' match only
# as a whole word, so names that merely contain those letters are kept.
_combo_pattern = r'\band\b|/'


def _find_top20_matches(df, name_col):
    """Return the rows of *df* whose *name_col* mentions a top-20 medication and is not a combination.

    Each row appears at most once, even if its name matches several top-20 entries.
    """
    if not top20_list:
        return df.iloc[0:0]  # an empty alternation would otherwise match every name
    names = df[name_col]
    is_top20 = names.str.contains(_top20_pattern, case=False, na=False)
    is_combo = names.str.contains(_combo_pattern, case=False, na=False)
    return df[is_top20 & ~is_combo]


# find matches in table_df and directory_df
matches_in_table_df = _find_top20_matches(table_df, 'name')
matches_in_directory_df = _find_top20_matches(directory_df, 'Name')

# generate the test set (taken before the layout filter below removes these rows)
test_df = matches_in_directory_df[matches_in_directory_df['Layout'] == 'C3PI_Test']

# keep only necessary images
matches_in_directory_df = matches_in_directory_df[matches_in_directory_df['Layout'].isin(['MC_API_NLMIMAGE_V1.3', 'MC_CHALLENGE_V1.0'])]

# remove unnecessary columns and rename columns
matches_in_table_df = matches_in_table_df[['name', 'nlmImageFileName']].rename(
    columns={'name': 'labels', 'nlmImageFileName': 'image_paths'})
matches_in_directory_df = matches_in_directory_df[['Image', 'Name']].rename(
    columns={'Image': 'image_paths', 'Name': 'labels'})
test_df = test_df[['Image', 'Name']].rename(columns={'Image': 'image_paths', 'Name': 'labels'})

# add a base label column for the top 20 medications
matches_in_table_df['base_label'] = matches_in_table_df['labels'].apply(get_base_label)
matches_in_directory_df['base_label'] = matches_in_directory_df['labels'].apply(get_base_label)
test_df['base_label'] = test_df['labels'].apply(get_base_label)

# encode labels: fit ONE encoder on every label so that a given medication name
# receives the same integer code in the training and test frames. Re-fitting
# per frame would make the codes incomparable once the frames are concatenated.
encoder = LabelEncoder()
encoder.fit(pd.concat([matches_in_table_df['labels'],
                       matches_in_directory_df['labels'],
                       test_df['labels']]))
matches_in_table_df['labels'] = encoder.transform(matches_in_table_df['labels'])
matches_in_directory_df['labels'] = encoder.transform(matches_in_directory_df['labels'])
test_df['labels'] = encoder.transform(test_df['labels'])

top20_instances_df = pd.concat([matches_in_table_df, matches_in_directory_df])

In [None]:
# DataFrame.size is rows * columns; len() gives the number of samples (rows).
print('training set size: ', len(top20_instances_df))
print('test set size: ', len(test_df))

In [None]:
top20_instances_df.head(n=5)  # preview the first five training rows

In [None]:
def _plot_label_distribution(counts, title, color):
    """Draw a bar chart of a ``value_counts()`` Series of base labels.

    Parameters
    ----------
    counts : pandas.Series
        Label -> count, as returned by ``Series.value_counts()``.
    title : str
        Figure title.
    color : str
        Matplotlib colour for the bars.
    """
    plt.figure(figsize=(10,6))
    plt.bar(counts.index, counts.values, alpha=0.5, color=color)
    plt.title(title)
    plt.xlabel('Base Label')
    plt.ylabel('Number of Labels')
    plt.xticks(rotation=90)
    plt.grid(True)
    plt.tight_layout()  # keep the long, rotated tick labels inside the figure
    plt.show()


# Check if the data is imbalanced in the training set
label_counts = top20_instances_df['base_label'].value_counts()
print(label_counts)
_plot_label_distribution(label_counts, 'Distribution of Base Labels (Training Set)', 'g')

# Check if the data is imbalanced in the test set
test_label_counts = test_df['base_label'].value_counts()
print(test_label_counts)
_plot_label_distribution(test_label_counts, 'Distribution of Base Labels (Test Set)', 'b')

# Downloading the Training Data

In [None]:
# Remote image archive, local mirror checked first, and destination folder
website_url, dataset_dir, training_dir = (
    'https://data.lhncbc.nlm.nih.gov/public/Pills/',
    './dataset',
    './training20_set',
)

# Create the destination folder on the first run only
if not os.path.exists(training_dir): os.makedirs(training_dir)

# Function to download an image from a URL and save it to a directory
def download_image(url, save_path, timeout=30):
    """Download *url* and write the response body to *save_path*.

    Parameters
    ----------
    url : str
        Full URL of the file to fetch.
    save_path : str
        Destination file path; its parent directory must already exist.
    timeout : float, optional
        Seconds to wait for the server to connect/respond. Without a timeout
        ``requests`` can block forever on a stalled connection.

    Returns
    -------
    bool
        True if the file was written; False on a non-200 status or a network
        error. A message is printed in both failure cases.
    """
    try:
        # `with` closes the streamed response and releases its connection.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download image from {url}")
                return False
            with open(save_path, 'wb') as f:
                # iter_content decodes gzip/deflate transfer encoding and wraps
                # low-level read errors in requests exceptions.
                for chunk in response.iter_content(chunk_size=65536):
                    f.write(chunk)
    except requests.RequestException as err:
        print(f"Failed to download image from {url}: {err}")
        # Don't leave a truncated file behind that could pass for a good image.
        if os.path.exists(save_path):
            os.remove(save_path)
        return False
    return True

# Copy each training image from the local mirror when present; otherwise download it.
for file_name in top20_instances_df['image_paths']:
    # NOTE(review): assumes rows without an image file name hold NaN; os.path.join
    # would raise TypeError on such a value, so skip anything that is not a string.
    if not isinstance(file_name, str):
        print(f"Skipping row without an image path: {file_name!r}")
        continue

    local_path = os.path.join(dataset_dir, file_name)
    save_path = os.path.join(training_dir, os.path.basename(file_name))
    if os.path.exists(local_path):
        shutil.copy(local_path, save_path)
        continue

    url = website_url + file_name
    if download_image(url, save_path):
        print(f"Downloaded {file_name} from {url}")
    else:
        print(f"Failed to find {file_name} in dataset_dir and download from {url}")


# Downloading the Test Data

In [None]:
testing_dir = './testing20_set'  # destination folder for the test images

# Create it on the first run only
_testing_dir_present = os.path.exists(testing_dir)
if not _testing_dir_present:
    os.makedirs(testing_dir)

# Copy each test image from the local mirror when present; otherwise download it.
for file_name in test_df['image_paths']:
    # NOTE(review): assumes rows without an image file name hold NaN; skip anything
    # that is not a string instead of crashing in os.path.join / str methods.
    if not isinstance(file_name, str):
        print(f"Skipping row without an image path: {file_name!r}")
        continue

    # Skip .wmv (Windows Media Video) files; only images are collected. The check is
    # case-insensitive so both '.WMV' and '.wmv' are caught.
    if file_name.lower().endswith('.wmv'):
        print(f"Skipping {file_name} as it has the .wmv extension")
        continue

    local_path = os.path.join(dataset_dir, file_name)
    save_path = os.path.join(testing_dir, os.path.basename(file_name))
    if os.path.exists(local_path):
        shutil.copy(local_path, save_path)
        continue

    url = website_url + file_name
    if download_image(url, save_path):
        print(f"Downloaded {file_name} from {url}")
    else:
        print(f"Failed to find {file_name} in dataset_dir and download from {url}")


In [None]:
# Report how many files ended up in each local image folder
n_train_files = len(os.listdir('./training20_set'))
print('Number of files in the training set: ', n_train_files)
n_test_files = len(os.listdir('./testing20_set'))
print('Number of files in the test set: ', n_test_files)