In [None]:
from CAP.data import download_food101,create_dataloader,save_result
from CAP.engine import train,eval_model
from CAP.model import Cap,getROIS

import torchvision.transforms as transforms
import torch.nn as nn
from torch.optim import Adam,SGD
import os
import torch
import torch.optim.lr_scheduler as lr_scheduler

In [None]:
# Fetch Food-101 into data/ only when data/food-101 is not already on disk.
# NOTE(review): food_101_train_path / food_101_test_path are bound only in the
# download branch. When data/food-101 already exists, a fresh kernel never
# defines them, yet the create_dataloader cell further down reads both names.
# Confirm the on-disk train/test locations and bind them in that case as well.
if not os.path.isdir('data/food-101'):
    food_101_train_path, food_101_test_path = download_food101(
        root='data',
        val_split=False,
    )
else:
    print("food101 dataset downloaded.")


In [None]:
# Preprocessing pipelines: an augmented one for training and a deterministic
# one for evaluation. Both end with ToTensor + the same per-channel Normalize.
bicubic = transforms.InterpolationMode.BICUBIC
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]

train_steps = [
    # Resize to 256x256 first so the augmentations below work on a fixed size.
    transforms.Resize((256, 256), interpolation=bicubic),
    # Random rotation within +/-15 degrees.
    transforms.RandomRotation(degrees=15),
    # Random crop, resized to 224x224.
    # NOTE(review): torchvision documents `scale` as a fraction of the source
    # image area, so the 1.15 upper bound asks for crops larger than the
    # image — confirm this range is intended.
    transforms.RandomResizedCrop(size=224, scale=(0.85, 1.15)),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
train_transform = transforms.Compose(train_steps)

test_steps = [
    # No augmentation at test time: resize straight to 224x224.
    transforms.Resize((224, 224), interpolation=bicubic),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std),
]
test_transform = transforms.Compose(test_steps)

In [None]:
# Build the train/test dataloaders and obtain the class names.
# os.cpu_count() returns None when the CPU count cannot be determined; fall
# back to 0 in that case rather than handing None to create_dataloader
# (num_workers is presumably forwarded to torch's DataLoader, which requires
# a non-negative int — confirm in CAP.data).
worker_count = os.cpu_count() or 0
train_dataloader, test_dataloader, class_names = create_dataloader(
    train_path=food_101_train_path,
    test_path=food_101_test_path,
    train_transform=train_transform,
    test_transform=test_transform,
    batch_size=96,
    num_workers=worker_count,
)


In [None]:
# Generate the ROI matrix and the ROI count; rois_mat is passed to the Cap
# constructor in the model cell.
rois_mat, num_rois = getROIS(
    resolution=42,
    gridSize=3,
    minSize=2,
)


In [None]:
# Instantiate the CAP model together with its loss, optimizer and LR schedule.
model = Cap(
    channels=1280,
    pool_size=7,
    # NOTE(review): hard-coded although getROIS also returns num_rois —
    # confirm the two values agree, or pass num_rois here instead.
    num_rois=26,
    rois_mat=rois_mat,
    feature_dim=1280 * 7 * 7,
    hidden_size=128,
    cluster_size=32,
    out_dim=101,
)
loss_fn = nn.CrossEntropyLoss()
optimizer = SGD(model.parameters(), lr=0.0001, momentum=0.99)
# StepLR scales the learning rate by gamma once every step_size scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)



In [None]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
# Kept as the cell's trailing expression, so the notebook still shows its value.
model.to(device=device)

In [None]:
# Run training for 150 epochs and keep the value returned by train().
# model_name / model_saving_dir are handed to train(), presumably for
# checkpointing — confirm in CAP.engine.
save_path = 'CAP_food101'
results = train(
    model=model,
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    epochs=150,
    device=device,
    model_name="CAP-food101.pth",
    model_saving_dir=save_path,
)


In [None]:
# Evaluate the model on the test dataloader; the class count is taken from
# class_names.
num_classes = len(class_names)
test_results = eval_model(
    model=model,
    dataloader=test_dataloader,
    loss_fn=loss_fn,
    num_classes=num_classes,
    device=device,
)


In [None]:
# Persist the training results, then the evaluation results, under save_path.
# Left as two separate statements so the cell's trailing expression (and any
# value the notebook would display for it) is the same as before.
save_result(data_dict=results, result_save_dir=save_path, result_name="CAP-food101_train_results")
save_result(data_dict=test_results, result_save_dir=save_path, result_name="CAP-food101_eval_results")