## Import Libraries

In [1]:
import os
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from torchvision.transforms import functional as TF
from skimage.metrics import structural_similarity as ssim
import tifffile 

## Enable Hardware

In [2]:
# Pick the Apple-silicon Metal backend (MPS) when PyTorch reports it as
# available; otherwise run on the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Device available:', device)

Device available: mps


## Create Dataset

### Data Path

In [None]:
# Candidate dataset root directories. Every path ends with a trailing "/".
lc_model_data_path = "data/model_data/"  # LC only
all_model_data_path = "data/all_model_data/"  # all data
# Machine-specific absolute path; written as adjacent literals that
# concatenate into one string.
brachial_plexus_data = (
    "/Users/mihirjoshi/Library/CloudStorage/"
    "OneDrive-TheOhioStateUniversity(2)/"
    "Honors Thesis/Ocular US Denoising/"
    "brachial_plexus_kaggle_data/ultrasound-nerve-segmentation/"
)

# Update these parameters to switch dataset / split folders.
data_path = brachial_plexus_data

train_label = 'train/'  # alternative: 'Train/'
val_label = 'test/'  # alternative: 'Validation/'

In [None]:
class EyeDataset(Dataset):
    """Dataset of image files read from a single directory.

    Every entry of ``directory`` whose name ends with one of
    ``IMAGE_EXTENSIONS`` becomes one sample. The scan is not recursive.
    Files are opened with PIL when a sample is requested.

    Args:
        directory: Path of the folder to scan for image files.
        transform: Optional callable applied to each opened PIL image
            before it is returned.
    """

    # File-name suffixes treated as images (matched case-sensitively).
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.bmp')

    def __init__(self, directory, transform=None):
        self.directory = directory
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable across runs and machines, so seeded
        # shuffling yields the same sample order everywhere.
        self.filenames = sorted(
            name
            for name in os.listdir(directory)
            if name.endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Open the ``idx``-th image and return it, transformed if requested.

        Returns the output of ``transform`` when one was given, otherwise
        the PIL image itself.
        """
        img_path = os.path.join(self.directory, self.filenames[idx])
        image = Image.open(img_path)
        if self.transform:  # Dynamically apply data transformation
            image = self.transform(image)
        return image

# Create a transform to convert the images to PyTorch tensors
transform = transforms.Compose([transforms.ToTensor()])

# Create the dataset for images. The split directories are built with
# os.path.join so that a data_path without a trailing separator still
# resolves to the right folder.
train_data = EyeDataset(os.path.join(data_path, train_label), transform=transform)
val_data = EyeDataset(os.path.join(data_path, val_label), transform=transform)

# Function to create data loader
def create_loader(train_dataset, batch_size):
    """Build a shuffling DataLoader over ``train_dataset``.

    Re-seeds PyTorch's global random number generator with 0 on every call
    (for reproducibility) before constructing the loader.

    Args:
        train_dataset: Dataset to draw batches from.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` with ``shuffle=True``.
    """
    torch.manual_seed(0)
    return DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# Record how many images each split contains, then report both counts.
train_dataset_size, val_dataset_size = len(train_data), len(val_data)

print('Number of images in the training dataset:', train_dataset_size)
print('Number of images in the validation dataset:', val_dataset_size)