In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms


from torchvision.models import vit_b_16
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import time

import argparse
import os
import copy
import dataloader


  from .autonotebook import tqdm as notebook_tqdm


In [8]:
def parse_args(argv=None):
    """Parse the command-line options for the OLIVES Prime ViT training notebook.

    Uses ``parse_known_args`` rather than ``parse_args`` so that arguments the
    parser does not recognise are handed back instead of aborting the run
    (presumably because the notebook kernel's own launch arguments end up in
    ``sys.argv`` -- TODO confirm).

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        tuple[argparse.Namespace, list[str]]: the parsed options and the list of
        argument strings that were not recognised.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--annot_train_prime', type = str, default = 'df_prime_train.csv')
    parser.add_argument('--annot_test_prime', type = str, default = 'df_prime_test.csv')
    parser.add_argument('--data_root', type = str, default = '/usr/scratch/abhimanyu/courses/ECE8803_FML/OLIVES/Prime_FULL')
    # Learning rate and momentum are fractional, so they must be parsed as float.
    # With type=int, a command-line value such as `--lr 0.001` is rejected; the
    # float defaults only worked because argparse does not apply `type` to
    # non-string defaults.
    parser.add_argument('--lr', type = float, default = 0.001)
    parser.add_argument('--momentum', type = float, default = 0.9)
    parser.add_argument('--epoch', type = int, default = 50)
    parser.add_argument('--batch_size', type = int, default = 5)
    # Directory in which the checkpoint file will be placed.
    parser.add_argument('--save_pth', type = str, default = '/usr/scratch/yangyu/FML_Model/vit')

    return parser.parse_known_args(argv)

In [9]:

# Set device: the CUDA GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# parse_args() returns (namespace, unrecognised_args); the leftovers are unused here.
args, unknown_args = parse_args()

# Build a timestamped checkpoint path inside the --save_pth directory. The prefix
# names the architecture this notebook actually trains (torchvision ViT-B/16).
CHECKPOINT_PREFIX = "vit_b_16_"
timestr = time.strftime("%Y%m%d-%H%M%S")
base_name = CHECKPOINT_PREFIX + timestr + ".pth"
name = os.path.join(args.save_pth, base_name)
# args.save_pth is overwritten in place: from here on it holds the absolute path
# of the checkpoint *file*, not the directory that was passed on the command line.
args.save_pth = os.path.abspath(name)

# Build the train/test loaders with the project-local `dataloader` module.
# NOTE(review): the 'ResNet' argument presumably selects a model-specific
# preprocessing preset and is reused here for ViT-B/16 -- TODO confirm it yields
# the 224x224 inputs torchvision's vit_b_16 expects by default.
batched_trainset, batched_testset = dataloader.dataloader(args, 'ResNet')


In [10]:
# Report how many items each loader yields per pass, train first, then test
# (presumably batch counts if these are DataLoaders -- TODO confirm).
print(*(len(split) for split in (batched_trainset, batched_testset)))

4851 1598


In [19]:
# Define model.
# NOTE(review): vit_b_16() is called without `weights` or `num_classes`, so
# torchvision's defaults apply (randomly initialised weights, 1000-way head).
# CrossEntropyLoss accepts that as long as every label is < 1000, but the head
# should presumably match the dataset's class count -- TODO confirm.
model = vit_b_16().to(device)

# Define optimizer and loss function.
# NOTE(review): the learning rate here and `num_epochs` below are hard-coded and
# ignore args.lr (default 0.001) and args.epoch (default 50) -- TODO decide which
# values are intended.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Tensorboard writer.
# TODO(review): 'path_to_tensorboard_logs' looks like a placeholder log directory.
writer = SummaryWriter('path_to_tensorboard_logs')


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Run one full pass over `loader`: train if `optimizer` is given, else evaluate.

    Args:
        model: the network; it is put in train mode when training and in eval
            mode otherwise.
        loader: iterable yielding (images, labels) pairs.
        criterion: loss function called as criterion(outputs, labels).
        device: device the images and labels are moved to before the forward pass.
        optimizer: when not None, each batch is followed by a backward pass and
            an optimizer step; when None, gradients are disabled.
        desc: when not None, wrap `loader` in a tqdm progress bar with this label.

    Returns:
        (mean_loss, accuracy): the loss averaged over the batches yielded and the
        fraction of correctly classified samples. Both are 0.0 if `loader`
        yields nothing, instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # toggles dropout etc.; eval mode for the test pass
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass (training only)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Record loss and accuracy
            loss_sum += loss.item()
            n_batches += 1
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
            n_seen += labels.size(0)

    mean_loss = loss_sum / n_batches if n_batches else 0.0
    accuracy = n_correct / n_seen if n_seen else 0.0
    return mean_loss, accuracy


# Train model, evaluating on the test data after every epoch.
num_epochs = 10
try:
    for epoch in range(num_epochs):
        start_time = time.time()
        train_loss, train_accuracy = _run_epoch(
            model, batched_trainset, criterion, device,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        test_loss, test_accuracy = _run_epoch(model, batched_testset, criterion, device)
        epoch_time = time.time() - start_time

        # Print statistics and add to Tensorboard
        print(
            f"Epoch {epoch+1}/{num_epochs} ({epoch_time:.1f}s) | "
            f"train loss {train_loss:.4f}, acc {train_accuracy:.4f} | "
            f"test loss {test_loss:.4f}, acc {test_accuracy:.4f}"
        )
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/test', test_loss, epoch)
        writer.add_scalar('Accuracy/train', train_accuracy, epoch)
        writer.add_scalar('Accuracy/test', test_accuracy, epoch)
        writer.add_scalar('Time/epoch_seconds', epoch_time, epoch)
finally:
    # Flush pending events and release the log file even if training is interrupted.
    writer.close()

2023-03-20 18:26:44.742112: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-20 18:26:49.298861: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
