## Imports

In [1]:
import os
import torch as T
import torchvision as TV
from math import prod
from tqdm.auto import tqdm
from datetime import datetime
from torch.utils.data import DataLoader
from FruitIST import FruitIST, kModelV1, kTrainerV1
from safetensors.torch import load_model, save_model

%matplotlib inline

## Hyperparameters

In [2]:
# --- Optimisation -----------------------------------------------------------
BATCH_SIZE     = 64      # samples per batch for both data loaders
NUM_EPOCHS     = 100     # number of passes made by the training loop
LEARNING_RATE  = 0.0001  # initial learning rate handed to Adam
LEARNING_STEP  = 5       # StepLR `step_size`
LEARNING_GAMMA = 0.1     # StepLR `gamma` (multiplicative decay factor)

# --- Logging / persistence --------------------------------------------------
TIME_FORMAT     = "%H:%M:%S"               # strftime pattern for log timestamps
CHECKPOINT_PATH = "fruitist_res_state.pt"  # file the model weights are saved to / loaded from

## Check for CUDA

In [3]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if T.cuda.is_available() else "cpu"
device = T.device(device_name)
device

device(type='cuda')

## Load Data

#### Create transform

In [4]:
# Single preprocessing step applied to every sample: convert the image tensor
# to float32.
transform = TV.transforms.ConvertImageDtype(dtype=T.float32)

#### Load data sets

In [5]:
# Build the train and test splits with identical settings; only the `train`
# flag differs. The generator is consumed left to right, so the training set
# is constructed first, then the testing set.
training_set, testing_set = (
    FruitIST(root='.', train=is_train, transform=transform)
    for is_train in (True, False)
)

# Display both dataset summaries as the cell output.
training_set, testing_set

(Dataset FruitIST
     Number of datapoints: 67692
     Root location: .
     StandardTransform
 Transform: ConvertImageDtype(),
 Dataset FruitIST
     Number of datapoints: 22688
     Root location: .
     StandardTransform
 Transform: ConvertImageDtype())

#### Create data loaders

In [6]:
# Settings shared by both loaders; only the dataset and the shuffling differ.
shared_loader_args = dict(
    batch_size=BATCH_SIZE,
    num_workers=2,
    prefetch_factor=2,
)

# Training batches are drawn in shuffled order.
train_loader = DataLoader(training_set, shuffle=True, **shared_loader_args)

# Evaluation batches keep the dataset order.
test_loader = DataLoader(testing_set, shuffle=False, **shared_loader_args)

#### Get data information

In [7]:
# Fetch the first training image once and reuse it. Indexing the dataset twice
# would load and transform the same sample twice, and the two shape reads
# would come from two separate objects.
sample_image = training_set[0][0]

num_channels = sample_image.shape[0]         # size of the first axis of the image tensor
num_classes  = len(training_set.classes)     # number of distinct labels in the training split
image_size   = prod(sample_image.shape[1:])  # product of the remaining axes

print(f"{num_channels = }")
print(f"{num_classes  = }")
print(f"{image_size   = }")

num_channels = 3
num_classes  = 131
image_size   = 10000


## Model

#### Create model

In [8]:
# ResNet-18 with its classification head sized to the number of dataset
# classes, then moved to the selected device.
model = TV.models.resnet18(num_classes=num_classes)
model = model.to(device)

#### Create utilities

In [9]:
# Loss: cross-entropy.
criterion = T.nn.CrossEntropyLoss()

# Optimizer: Adam over every model parameter.
optimizer = T.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Scheduler: multiply the learning rate by LEARNING_GAMMA every LEARNING_STEP
# scheduler steps.
scheduler = T.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LEARNING_STEP,
    gamma=LEARNING_GAMMA,
)

#### Create trainer

In [10]:
# Bundle the model with its loss, optimizer and scheduler into the project's
# trainer helper; later cells drive training and evaluation through it.
trainer = kTrainerV1(
    model=model,
    loss_function=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

#### Check for saved state

In [11]:
# Best test-set accuracy seen so far. The training loop compares each epoch's
# accuracy against this value to decide whether to overwrite the checkpoint,
# so it must hold an accuracy -- never a loss.
best = float("-inf")

if os.path.isfile(CHECKPOINT_PATH):
    print("Saved state detected. Loading file . . . ")

    try:
        load_model(model, CHECKPOINT_PATH)
    except Exception as error:
        # Best effort: an unreadable or incompatible checkpoint must not abort
        # the notebook, but the reason is reported instead of being swallowed.
        # Catching Exception (not a bare except) lets KeyboardInterrupt through.
        print(f"Unable to load saved state. Starting over ({type(error).__name__}: {error})")
    else:
        # Evaluate the restored weights: loss on the training loader and
        # accuracy on the test loader.
        loss     = trainer.test(train_loader,   progress_bar=True)
        accuracy = trainer.predict(test_loader, progress_bar=True)

        # Seed the baseline with the restored accuracy so that a worse epoch
        # cannot overwrite the checkpoint that was just loaded.
        best = accuracy

        print(f"Loaded saved state with a loss of {loss:.2e} accuracy of {accuracy:.2%}")
else:
    print("No saved state file detected")

Saved state detected. Loading file . . . 


Testing:   0%|          | 0/1058 [00:03<?, ?Batches/s]

Testing:   0%|          | 0/355 [00:00<?, ?Batches/s]

Loaded saved state with a loss of 2.50e-05 accuracy of 98.88%


## Learn

#### Run training loop

In [None]:
# Announce the start of training with a wall-clock timestamp.
print(f"[{ datetime.now().strftime(TIME_FORMAT) }] Starting Training . . . ")

for epoch in tqdm(range(NUM_EPOCHS), desc="Training", unit="Epoch"):
    # Timestamp taken before the epoch so its duration can be reported below.
    time = datetime.now()

    # One pass over the training loader, then an evaluation on the test loader.
    loss     = trainer.forward(train_loader, progress_bar=True)
    accuracy = trainer.predict(test_loader,  progress_bar=True)

    # Checkpoint whenever this epoch matches or beats the best accuracy so far.
    improved = accuracy >= best
    saved    = " | Saved" if improved else ""

    if improved:
        best = accuracy
        save_model(model, CHECKPOINT_PATH)

    print(f"[{datetime.now().strftime(TIME_FORMAT)}] (+{(datetime.now() - time).total_seconds():.0f}s) Epoch: {epoch +1}/{NUM_EPOCHS} | Loss = {loss:.2e} | Accuracy = {accuracy:.2%}{saved}")

[06:11:58] Starting Training . . . 


Training:   0%|          | 0/100 [00:00<?, ?Epoch/s]

Training:   0%|          | 0/1058 [00:00<?, ?Batches/s]