In [None]:
# Mount Google Drive at /content/drive; the paths used in the following cells
# all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Switch the working directory to the repository root on Drive.
# NOTE(review): the later `src.*` imports presumably rely on this — confirm.
import os
repo_path = "/content/drive/MyDrive/EmbarkLabs/MammoViT"
os.chdir(repo_path)
# Show where we ended up and list the directory contents as a sanity check.
print("Current directory:", os.getcwd())
!ls

In [None]:
import torch
# Release cached GPU memory that PyTorch holds but is not currently using,
# so it is available again before training starts.
torch.cuda.empty_cache()

In [None]:
# Strip commas from the names of the entries directly under the dataset root
# (the image directories), renaming them in place on Drive.
#
# The cell is safe to re-run: names without a comma are left alone, and a
# rename is skipped (and reported) when the comma-free name already exists.
# Without that guard, os.rename could either raise part-way through the loop,
# leaving the dataset half-renamed, or replace the existing entry.
root = "/content/drive/MyDrive/EmbarkLabs/imagenet1k"

for folder in os.listdir(root):
    if ',' not in folder:
        continue

    new_name = folder.replace(',', '')
    old_path = os.path.join(root, folder)
    new_path = os.path.join(root, new_name)

    if os.path.exists(new_path):
        # Two different names can collapse to the same comma-free name, or a
        # previous run may already have created the target.
        print(f"Skipping {folder!r}: {new_name!r} already exists")
        continue

    os.rename(old_path, new_path)


# Load Data
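This section sets up loading of the training and validation datasets with PyTorch's `DataLoader`: it imports the `create_dataloaders` helper and defines the data directory and batch size.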

In [None]:
from pathlib import Path
# NOTE(review): create_dataloaders is imported here but not called in this cell.
from src.train_fine_tune.pretrain_vit import create_dataloaders

# Dataset root on Drive and the batch size. Both are assigned again, with the
# same values, in the training cell further down.
data_dir = Path('/content/drive/MyDrive/EmbarkLabs/imagenet1k')
batch_size = 64

# Instantiate Model and Start Training
This section initializes the Vision Transformer (ViT) model, sets up the optimizer and loss function, and starts the training process.

In [None]:
from src.train_fine_tune.pretrain_vit import train_vit_model
from pathlib import Path

# Dataset root, and the directory handed to train_vit_model as log_dir. The
# figures cell later reads metrics.csv and confusion_matrix.json from log_dir.
data_dir = Path('/content/drive/MyDrive/EmbarkLabs/imagenet1k')
log_dir = Path('/content/drive/MyDrive/EmbarkLabs/MammoViT/logs')

# Training hyperparameters.
epochs = 10
batch_size = 64
learning_rate = 1e-4
# NOTE(review): save_path is assigned but not passed to train_vit_model in this
# cell — confirm whether it is still needed or should be an argument.
save_path = "logs/preTrainedViT"

# Arguments are passed positionally, so their order must match the signature
# of train_vit_model.
train_vit_model(data_dir, log_dir, epochs, batch_size, learning_rate)


# Pull Metrics and Create Figures
This section reads the logged metrics and generates visualizations for the loss, the accuracy, and the confusion matrix.

In [None]:
from src.eval_and_metrics.figures import plot_loss_curve, plot_accuracy_curve, plot_confusion_matrix

# Inputs found under log_dir, and the folder that receives the rendered figures.
metrics_path = log_dir.joinpath('metrics.csv')
cm_path = log_dir.joinpath('confusion_matrix.json')
output_path = log_dir.joinpath('figures')
output_path.mkdir(exist_ok=True, parents=True)

# The loss and accuracy curves take the same metrics file and output folder.
for _plot_curve in (plot_loss_curve, plot_accuracy_curve):
    _plot_curve(metrics_path, output_path)

# Confusion matrix.
# NOTE(review): these are four placeholder labels — confirm they match the
# classes stored in confusion_matrix.json.
class_names = ['class1', 'class2', 'class3', 'class4']
plot_confusion_matrix(cm_path, output_path, class_names)