## Test Framework for all Models
#### Author: Ayush Tripathi (atripathi7783@gmail.com)

In [1]:
#imports
import os
import time
import pydicom
from tqdm import tqdm
import pandas as pd
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import logging
import sys
from pathlib import Path


project_root = Path(os.getcwd()).resolve().parents[1]
sys.path.append(str(project_root))
print(project_root) #should be base file

C:\Users\atrip\Classes\ECS-174-Project


### Dataloading

In [None]:
'''
Goal: Load and store dataloader as an object for use throughout future models. 
'''

#import the class from the project root
from src.etl.data_loading import LumbarSpineDataset

manual_seed = 110
torch.manual_seed(manual_seed)
print(f"manual seed: {manual_seed}")
# Initialize the dataset
image_dir = r"C:\Users\atrip\Classes\ECS-174-Project\src\dataset\rsna-2024-lumbar-spine-degenerative-classification\train_images"
metadata_dir = r"C:\Users\atrip\Classes\ECS-174-Project\src\dataset\rsna-2024-lumbar-spine-degenerative-classification" 
transform = transforms.Compose([  
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5), (0.5))])

dataset = LumbarSpineDataset(image_dir=image_dir, metadata_dir=metadata_dir, transform=transform, load_fraction=1)

# Create DataLoader with tqdm for progress bar
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

manual seed: 110


Loading images:   1%|          | 444/48692 [00:01<02:51, 280.89it/s]


KeyboardInterrupt: 

### Load Model 

In [None]:
from src.arch.unet import *  #change depending on model used
# from src.arch.cnn import *

#define params
NUM_INPUT_CHANNELS = 3  
NUM_OUTPUT_CLASSES = 10

#initialize
model = UNet(NUM_INPUT_CHANNELS, NUM_OUTPUT_CLASSES)

#test with random shape
example_input = torch.randn(1, NUM_INPUT_CHANNELS, 512, 512)  # Batch size 1, RGB image of 224x224
output = model(example_input)
print(output)


tensor([[[[-0.2766, -0.3923,  0.4404,  ..., -0.1394, -0.1185, -0.2920],
          [-0.1991, -0.1440, -0.0740,  ..., -0.2165,  0.1650, -0.3053],
          [ 0.2327, -0.3668,  0.4019,  ...,  0.3188, -0.0732, -0.0957],
          ...,
          [ 0.0899,  0.2020,  0.1553,  ..., -0.2785, -0.3658, -0.3978],
          [-0.0140, -0.3717,  0.3657,  ..., -0.2634, -0.3188, -0.0986],
          [-0.1778, -0.1772,  0.0589,  ...,  0.0316, -0.2206, -0.3987]],

         [[-0.1312, -0.4847, -0.2530,  ..., -0.0217, -0.3034, -0.1832],
          [-0.3285, -0.6494, -0.0032,  ..., -0.2336, -0.7965, -0.3267],
          [-0.5735,  0.0722, -0.0109,  ..., -0.4004, -0.6150, -0.0697],
          ...,
          [-0.4727, -0.5641, -0.3501,  ..., -0.6341, -0.3996, -0.5544],
          [-0.8073, -0.2831, -0.4791,  ..., -0.2603, -0.4110,  0.1284],
          [-0.4233, -0.1227, -0.5186,  ..., -0.1678, -0.1815, -0.2103]],

         [[ 0.0806,  0.3720, -0.2275,  ...,  0.1129,  0.1748,  0.0709],
          [ 0.3648, -0.0810,  