In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image

from tqdm import tqdm

import tensorflow as tf

from transformers import ViTFeatureExtractor, TFAutoModelForImageClassification
from transformers import ViTForImageClassification, ViTConfig, ViTImageProcessor

from sklearn.model_selection import train_test_split


import torch
import torch.optim as optim
import torch.nn as nn
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Dataset location, relative to the working directory selected below.
DATA_DIR  = 'mimic-data/'
# Parent of the directory the kernel was started in (captured before chdir).
# Not referenced by any other cell visible in this notebook.
ROOT = os.path.dirname(os.getcwd())

# The project checkout was previously hard-coded to a single machine. Allow it
# to be overridden with the RADIOCARE_ROOT environment variable while keeping
# the original location as the default.
PROJECT_ROOT = os.environ.get('RADIOCARE_ROOT', 'E:/RadioCareBorealisAI')
os.chdir(PROJECT_ROOT)

# Make the project importable regardless of how the kernel populates sys.path,
# instead of relying on the working directory being searched implicitly.
if os.getcwd() not in sys.path:
    sys.path.insert(0, os.getcwd())

from data_modules.mimic_cxr import MimicIVCXR

def seed_everything(seed: int) -> None:
    """Seed every random number generator used by the notebook.

    Seeds Python's built-in ``random`` module, NumPy's global generator and
    PyTorch (CPU and all CUDA devices), then configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Local import so the function does not depend on a top-of-notebook import.
    import random

    # Python's own RNG was previously left unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Deterministic kernels, no autotuning; cuDNN is additionally switched off
    # entirely, as in the original configuration.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False

In [None]:
class args:
    """Notebook-wide settings, read as plain class attributes (never instantiated here)."""

    # Reproducibility
    seed = 23

    # Classification head size and training length
    num_labels = 2
    num_epochs = 5

    # Batch size for each data split
    train_batch_size = 48
    valid_batch_size = 16
    test_batch_size = 16

# Fix all random seeds first, then select the compute device:
# CUDA when PyTorch reports it as available, otherwise the CPU.
seed_everything(args.seed)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Path of the report CSV inside DATA_DIR. os.path.join only inserts a separator
# when DATA_DIR lacks a trailing one, avoiding the doubled "//" that plain
# string formatting produced.
graph_report_dir = os.path.join(DATA_DIR, "graph_report.csv")

# Image preprocessing that matches the ViT checkpoint fine-tuned later on.
processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')

# NOTE(review): `tokenizer` is passed as a *string* containing Python source,
# not as a tokenizer object. The training and test loops discard the text
# field, so this may be a deliberate placeholder; confirm against what
# MimicIVCXR expects.
dataset = MimicIVCXR(data_root=DATA_DIR,
                     graph_report_dir=graph_report_dir,
                     tokenizer="AutoTokenizer.from_pretrained('bert-base-uncased')",
                     max_length=3000,
                     transform=processor)


In [None]:
class LimitedDataset:
    """Read-only view over the first ``limit`` items of an indexable dataset.

    Args:
        full_dataset: Any object supporting ``len()`` and integer indexing.
        limit: Maximum number of items to expose. It is clamped to the range
            ``[0, len(full_dataset)]``.

    Indexing accepts ``0 <= index < len(self)`` and, like built-in sequences,
    negative indices counted from the end of the *limited* view. Anything
    else raises ``IndexError``.
    """

    def __init__(self, full_dataset, limit=20000):
        self.full_dataset = full_dataset
        # Never expose more items than exist, and keep len() non-negative.
        self.limit = max(0, min(limit, len(full_dataset)))

    def __getitem__(self, index):
        # Resolve negative indices against the limited length. Previously they
        # fell through to the wrapped dataset and returned items from its
        # tail, i.e. from outside the limit.
        if index < 0:
            index = index + self.limit
        if 0 <= index < self.limit:
            return self.full_dataset[index]
        raise IndexError("Index out of range")

    def __len__(self):
        return self.limit

# Cap the dataset at LimitedDataset's default limit (20,000 samples). The guard
# keeps this cell idempotent: re-running it does not stack another wrapper.
if not isinstance(dataset, LimitedDataset):
    dataset = LimitedDataset(dataset)

# Sanity check: report the size and fetch the last item inside the limit.
print(len(dataset))  # 20000, or fewer if the underlying dataset is smaller
if len(dataset) > 0:
    # Index relative to the actual length so a dataset with fewer than 20,000
    # samples does not raise IndexError.
    print(dataset[len(dataset) - 1])

In [None]:
# train_data, val_test_data = train_test_split(dataset, test_size=0.2)
# val_data, test_data = train_test_split(val_test_data, test_size=0.5)

In [None]:
# Split the dataset 80/10/10 into train / validation / test subsets. The split
# uses its own seeded generator, so it is reproducible no matter how much
# randomness earlier cells have consumed.
num_samples = len(dataset)
num_val = num_samples // 10
num_test = num_samples // 10
num_train = num_samples - num_val - num_test

train_data, val_data, test_data = torch.utils.data.random_split(
    dataset,
    [num_train, num_val, num_test],
    generator=torch.Generator().manual_seed(args.seed),
)

# Only the training loader shuffles; evaluation order is kept fixed.
train_dataloader = DataLoader(train_data, batch_size=args.train_batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=args.valid_batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False)

In [None]:
# Load the checkpoint's configuration, overriding only the number of output
# labels, then build the classifier from the same checkpoint with it.
config = ViTConfig.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=args.num_labels,
)
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    config=config,
)
# Move the parameters to the selected device (the returned module is displayed).
model.to(device)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-5)

# Per-epoch history: mean per-sample training loss and validation accuracy (%).
train_losses = []
val_accuracies = []

# Validation is optional: it runs only when a `val_dataloader` has been defined
# in the notebook namespace. Without one, NaN is recorded for the accuracy
# instead of dividing by zero.
val_loader = globals().get("val_dataloader")

for epoch in range(args.num_epochs):

    # ---- Training phase ----
    model.train()
    running_loss = 0.0  # sum over batches of (batch-mean loss * batch size)
    num_seen = 0        # number of training samples processed this epoch
    for images, _, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{args.num_epochs}", unit="batch", total=len(train_dataloader)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

    # The running sum is weighted by batch size, so it has to be normalised by
    # the number of samples; dividing by the number of batches inflated the
    # reported loss by roughly the batch size.
    avg_train_loss = running_loss / num_seen if num_seen else float("nan")
    train_losses.append(avg_train_loss)

    # ---- Validation phase ----
    model.eval()
    correct = 0
    total = 0
    if val_loader is not None:
        with torch.no_grad():
            for images, _, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images).logits
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

    # Guard against an absent or empty validation set.
    val_accuracy = 100 * correct / total if total else float("nan")
    val_accuracies.append(val_accuracy)

    print(f"Epoch {epoch+1}/{args.num_epochs}, Loss: {avg_train_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

In [None]:
# Mean per-sample training loss of the most recent epoch. `running_loss` is a
# sum of batch losses weighted by batch size, so it is normalised by the number
# of training samples rather than by the number of batches.
running_loss / len(train_dataloader.dataset)

In [None]:
# Top-1 accuracy on the held-out test split.
# `test_dataloader` must be defined by an earlier cell before this one runs.
correct = 0
total = 0

model.eval()
with torch.no_grad():
    # Each batch is (images, text, labels); the text field is not used here.
    for images, _, labels in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Guard against an empty test loader instead of raising ZeroDivisionError.
test_accuracy = 100 * correct / total if total else float("nan")
print(f"Test Accuracy: {test_accuracy:.2f}%")