In [5]:
import yaml
import torch
from datasets import get_dataset
from utils.get_transforms import get_transforms
from utils.dataloader import get_dataloaders
from models.model import get_model
from utils.train_utils import train, evaluate, save_checkpoint


In [10]:
# Load the experiment configuration and expose each top-level section as
# its own variable for the cells below.
with open('configs/config.yaml', 'r') as f:
    cfg = yaml.safe_load(f)

# Unpack the sections in a fixed order; a missing section raises KeyError here.
(datasets_cfg, models_cfg, train_cfg,
 optim_cfg, checkpoint_cfg, output_cfg) = (
    cfg[section]
    for section in ('dataset', 'model', 'train', 'optim', 'checkpoint', 'output')
)

In [12]:
# Optional reproducibility: seed torch's RNGs when the config defines
# train.seed. When the key is absent, behavior is unchanged (unseeded).
seed = train_cfg.get('seed')
if seed is not None:
    torch.manual_seed(seed)

train_transform, test_transform = get_transforms()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Each split must use its own transform. The test split gets test_transform,
# not train_transform; the training pipeline presumably contains random
# augmentation that would distort evaluation.
train_dataset = get_dataset(datasets_cfg['name'], root=datasets_cfg['root'], split='train', transform=train_transform)
test_dataset = get_dataset(datasets_cfg['name'], root=datasets_cfg['root'], split='test', transform=test_transform)

# The key is 'test' so that dataloaders['test'] (used by the training loop)
# resolves. This assumes get_dataloaders returns a dict keyed like
# dataset_dict; TODO confirm.
dataset_dict = {'train': train_dataset, 'test': test_dataset}

dataloaders = get_dataloaders(dataset_dict, batch_size=datasets_cfg['batch_size'], num_workers=datasets_cfg['num_workers'], shuffle_train=True)

In [14]:
# Extra architecture kwargs from the config; an empty dict when 'params' is omitted.
params = models_cfg.get('params', {})

# Build the network for the configured class count, then place it on the
# device selected earlier.
model = get_model(
    models_cfg['name'],
    models_cfg['num_classes'],
    params=params,
)
model = model.to(device)

Downloading: "https://download.pytorch.org/models/swin_t-704ceda3.pth" to C:\Users\Milad/.cache\torch\hub\checkpoints\swin_t-704ceda3.pth


100%|██████████| 108M/108M [00:26<00:00, 4.21MB/s] 


In [15]:
# Classification objective over the model's class logits.
loss_fn = torch.nn.CrossEntropyLoss()
# Adam over every model parameter, with the learning rate from the optim section.
optimizer = torch.optim.Adam(model.parameters(), lr=optim_cfg['lr'])

In [None]:
# Train for the configured number of epochs. After each epoch, evaluate on the
# test loader and checkpoint only when accuracy beats the best seen so far.
# NOTE(review): model selection uses the test split (there is no separate
# validation split), so the reported best accuracy is optimistically biased.
best_acc = 0.0
epochs = train_cfg['epochs']

for epoch in range(epochs):
    train_loss = train(model, dataloaders['train'], optimizer, loss_fn, device)
    acc = evaluate(model, dataloaders['test'], device)
    print(f"epoch:{epoch+1}/{epochs}: train loss: {train_loss:.4f} , accuracy: {acc:.2f}%")
    if acc > best_acc:
        # Track the new best so later epochs only checkpoint on real
        # improvements. Without this, any epoch above 0 would overwrite it.
        best_acc = acc
        save_checkpoint(model, models_cfg['name'])

print("training_complete")
print(f"best accuracy: {best_acc:.2f}%")