In [None]:
import torch

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print("Using device:", device)

# Extra diagnostics that are only queried when a CUDA device was selected.
if use_cuda:
    # All queries below target CUDA device index 0.
    print(torch.cuda.get_device_name(0))
    print("Memory Usage:")
    # Scale the raw counters by 1024**3 and round to one decimal before
    # printing them with the "GB" suffix.
    scale = 1024**3
    for caption, probe in (
        ("Allocated:", torch.cuda.memory_allocated),
        ("Cached:   ", torch.cuda.memory_reserved),
    ):
        print(caption, round(probe(0) / scale, 1), "GB")

# To monitor GPU usage continuously (refreshing every 2 seconds), run: watch -n 2 nvidia-smi

In [None]:
# Batch size passed to the dataset loaders when the splits are created.
BATCH_SIZE: int = 20
# Number of epochs the training loop runs for.
EPOCHS: int = 20

In [None]:
from common.data_pipeline.MMCBNU_6000.dataset import DatasetLoader as mmcbnu
from common.data_pipeline.FV_USM.dataset import DatasetLoader as fvusm
from common.util.data_pipeline.dataset_chainer import DatasetChainer
from common.util.enums import EnvironmentType

# Every loader and the chained dataset are configured for the PyTorch environment.
environment = EnvironmentType.PYTORCH

# MMCBNU_6000 is requested with included_portion=1 and FV_USM with 0.
# NOTE(review): presumably included_portion is the fraction of each dataset to
# use, which would mean FV_USM contributes no samples here — confirm.
chained_sources = [
    mmcbnu(environment_type=environment, included_portion=1),
    fvusm(environment_type=environment, included_portion=0),
]
datasets = DatasetChainer(datasets=chained_sources)

# get_dataset's result is unpacked in the order (train, test, validation).
splits = datasets.get_dataset(environment, batch_size=BATCH_SIZE)
train, test, validation = splits

In [None]:
# Unpack the two objects stored on the training split's underlying dataset and
# print their shapes as a quick sanity check.
image, labels = train.dataset.data
print(*(array.shape for array in (image, labels)))

In [None]:
from common.util.enums import DatasetSplitType


# datasets.get_files(DatasetSplitType.TRAIN)

In [None]:
from common.train_pipeline.isotropic_vig import isotropic_vig_ti_224_gelu

# Build the isotropic ViG model with the factory's default arguments.
model = isotropic_vig_ti_224_gelu()
# Move the model to the device selected at the top of the notebook.
model.to(device)
# Prints an empty line. NOTE(review): presumably kept so the cell does not end
# on `model.to(device)`, whose return value the notebook would otherwise echo.
print()
# Uncomment to inspect the architecture:
# print(model)

In [None]:
import torch.optim as optim
import torch.nn as nn
import torch
from tqdm import tqdm

optimizer = optim.Adam(model.parameters(), lr=0.001)
jsd = None
mixup_active = None

# Place the criteria on `device` instead of calling .cuda(): the notebook falls
# back to the CPU when no GPU is available, and .cuda() fails in that case.
train_loss_fn = nn.CrossEntropyLoss().to(device)
validate_loss_fn = nn.CrossEntropyLoss().to(device)

# Training loop: each epoch is one optimisation pass over `train` followed by
# one gradient-free evaluation pass over `validation`.
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    train_batches = 0
    for inputs, labels in tqdm(train, desc=f"Epoch {epoch}: "):
        inputs = inputs.float().to(device)
        # Float targets are interpreted by CrossEntropyLoss as per-class
        # probabilities.
        # NOTE(review): assumes labels are one-hot with shape (batch, classes) — confirm.
        labels = labels.float().to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = train_loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        # detach() keeps the running sum out of the autograd graph.
        train_loss += loss.detach()
        train_batches += 1

    model.eval()
    val_loss = 0.0
    val_batches = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in validation:
            inputs = inputs.float().to(device)
            labels = labels.float().to(device)
            outputs = model(inputs)
            val_loss += validate_loss_fn(outputs, labels)
            val_batches += 1

            # The predicted class is the index of the largest logit in each row.
            # The previous `outputs == outputs.max()` compared against the
            # maximum of the whole batch, and `predicted & labels` applied a
            # bitwise AND to float tensors, which PyTorch does not support.
            predicted = outputs.argmax(dim=1)
            # NOTE(review): relies on the same one-hot label assumption as above.
            targets = labels.argmax(dim=1)
            correct += (predicted == targets).sum().item()
            total += labels.size(0)

    # Report per-batch means so the two losses are comparable. float() accepts
    # both the initial 0.0 and a 0-dim tensor, and the max()/`if total` guards
    # keep an empty loader from raising.
    mean_train_loss = float(train_loss) / max(train_batches, 1)
    mean_val_loss = float(val_loss) / max(val_batches, 1)
    accuracy = (correct / total) * 100 if total else 0.0
    print(
        f"Epoch [{epoch+1}/{EPOCHS}], Loss: {mean_train_loss:.4f}, Val Loss: {mean_val_loss:.4f}, Accuracy: {accuracy:.2f}%"
    )
# Leave the model in training mode once all epochs have finished.
model.train()

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np

# Preview one sample image at its original size and after resizing.
# The path is relative to the notebook's working directory.
image_path = "./datasets/MMCBNU_6000/ROIs/084/L_Fore/02.bmp"
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread reports a missing or unreadable file by returning None rather than
# raising, so fail here with a clear message instead of deep inside matplotlib.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
plt.imshow(image, cmap="gray")
plt.show()
# cv2.resize takes the target size as (width, height): 120 wide by 60 tall.
res_img = cv2.resize(image, (120, 60))
plt.imshow(res_img, cmap="gray")
plt.show()