In [2]:
import torch
import timm
import wandb
from tqdm.auto import tqdm
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor, Compose, Resize, Normalize, CenterCrop
from torch.utils.data import DataLoader


# --- Configuration -----------------------------------------------------------
RESIZED_SIZE = 256                      # side length (pixels) of the square model input
TRAIN_ROOT = 'd:/Data/PTX/train'        # ImageFolder root for the training split
VAL_ROOT = 'd:/Data/PTX/validation'     # ImageFolder root for the validation split
BATCH_SIZE = 32
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Preprocessing -----------------------------------------------------------
# The same pipeline is shared by the training and validation datasets.
# NOTE(review): Normalize is imported above but not applied here - confirm
# whether the pretrained backbones expect normalised inputs.
# NOTE(review): CenterCrop uses the same size that Resize already produces.
transforms = Compose([
    ToTensor(),
    Resize(size=(RESIZED_SIZE, RESIZED_SIZE)),
    CenterCrop(size=(RESIZED_SIZE, RESIZED_SIZE)),
])

# --- Datasets ----------------------------------------------------------------
train_dataset = ImageFolder(TRAIN_ROOT, transform=transforms)
val_dataset = ImageFolder(VAL_ROOT, transform=transforms)

# --- Loaders -----------------------------------------------------------------
# Only the training loader shuffles; validation keeps the dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def train(model, train_dataloader, val_dataloader, loss_fn, optimizer, epochs=10, device='cpu', log_experiment=True):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Every epoch calls ``train_epoch`` on ``train_dataloader`` and ``test_epoch``
    on ``val_dataloader`` and prints the resulting loss/accuracy pairs. After
    the last epoch the weights are written to ``last.pt``.

    When ``log_experiment`` is true, a Weights & Biases run is opened in project
    ``PTX``, the four metrics are logged every epoch, ``best.pt`` is written
    whenever an epoch has both the highest validation accuracy and the lowest
    validation loss seen so far, and the checkpoints are uploaded as a
    ``weights`` model artifact.

    Args:
        model: network to train; it is used as given (not moved to ``device``).
        train_dataloader: loader handed to ``train_epoch``.
        val_dataloader: loader handed to ``test_epoch``.
        loss_fn: loss callable forwarded to both epoch helpers.
        optimizer: optimizer forwarded to ``train_epoch``.
        epochs: number of epochs to run.
        device: device the epoch helpers move each batch to.
        log_experiment: enable W&B logging and best-checkpoint tracking.
    """
    run = None
    best_saved = False  # becomes True once this call has written best.pt

    if log_experiment:
        val_loss_history = []
        val_acc_history = []
        artifact = wandb.Artifact('weights', type='model')

        run = wandb.init(
            project="PTX",
            config={
                "model": model.__class__.__name__,
                "epochs": epochs,
            })

    try:
        for epoch in tqdm(range(epochs)):
            print(f'Epoch {epoch+1}\n-------------------------------')
            train_loss, train_acc = train_epoch(model, train_dataloader, loss_fn, optimizer, device)
            print(f'train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}')
            val_loss, val_acc = test_epoch(model, val_dataloader, loss_fn, device)
            print(f'val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}')
            if log_experiment:
                val_loss_history.append(val_loss)
                val_acc_history.append(val_acc)

                # A new "best" requires the epoch to be the best on both metrics.
                if val_acc == max(val_acc_history) and val_loss == min(val_loss_history):
                    torch.save(model.state_dict(), 'best.pt')
                    best_saved = True

                wandb.log({'train_loss': train_loss, 'train_acc': train_acc, 'val_loss': val_loss, 'val_acc': val_acc})

        torch.save(model.state_dict(), 'last.pt')

        if log_experiment:
            artifact.add_file('last.pt')
            # Skip best.pt when no epoch ran, so a missing or stale file from an
            # earlier run is never attached to this artifact.
            if best_saved:
                artifact.add_file('best.pt')
            # log_artifact alone persists and uploads the artifact; a preceding
            # artifact.save() would log it a second time.
            run.log_artifact(artifact)
    finally:
        # Close the W&B run even if training raises or is interrupted, so the
        # next wandb.init() does not inherit a dangling run.
        if run is not None:
            run.finish()


def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to optimise; switched to training mode.
        dataloader: iterable of ``(X, y)`` tensor batches.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor. If it
            exposes ``reduction == 'sum'`` the batch loss is treated as a sum,
            otherwise as a per-sample mean.
        optimizer: optimizer stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss, accuracy)``: the per-sample average loss and the fraction of
        correct argmax predictions, both over the samples actually seen.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    # A mean-reduced loss must be weighted by its batch size before it is
    # divided by the sample count; otherwise the reported loss is too small
    # by roughly a factor of the batch size.
    sum_reduced = getattr(loss_fn, 'reduction', 'mean') == 'sum'

    model.train()
    total_loss, correct, seen = 0.0, 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        total_loss += loss.item() if sum_reduced else loss.item() * batch_size
        correct += (pred.argmax(1) == y).sum().item()
        seen += batch_size

    if seen == 0:
        raise ValueError('dataloader yielded no samples')
    return total_loss / seen, correct / seen

def test_epoch(model, dataloader, loss_fn, device):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= size
    correct /= size
    return test_loss, correct

def train_all_models(models, train_dataloader, val_dataloader, loss_fn, lr, epochs=10, device='cpu', log_experiment=True):
    """Train each network in ``models`` one after another via ``train``.

    A fresh Adam optimizer with learning rate ``lr`` is built for every
    network, the network is moved to ``device``, and all remaining arguments
    are forwarded to ``train`` unchanged.
    """
    print(f'Starting ensemble model training for {epochs} epochs')
    for net in models:
        print(f'Training {net.__class__.__name__}')
        adam = torch.optim.Adam(net.parameters(), lr=lr)
        net_on_device = net.to(device)
        train(
            net_on_device,
            train_dataloader,
            val_dataloader,
            loss_fn,
            adam,
            epochs=epochs,
            device=device,
            log_experiment=log_experiment,
        )

In [11]:
# Validates a model on val_dataloader, and displays the results
def validate(model, val_dataloader, device='cpu'):
    """Evaluate ``model`` on ``val_dataloader`` with cross-entropy loss.

    Args:
        model: network to evaluate; it is moved to ``device`` first.
        val_dataloader: loader handed to ``test_epoch``.
        device: device used for both the model and the batches.

    Returns:
        ``(loss, acc)`` exactly as returned by ``test_epoch``; the pair is
        also printed.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    # test_epoch moves only the batches to `device`; move the model as well so
    # a model left on another device does not cause a device-mismatch error.
    model = model.to(device)
    loss, acc = test_epoch(model, val_dataloader, loss_fn, device)
    print(f'loss: {loss:.4f}, acc: {acc:.4f}')
    return loss, acc

In [7]:
# Five pretrained timm backbones, each built with a 3-class head and a
# drop rate of 0.5. They are created in the order listed.
timm_resnet, timm_inception_v3, timm_inception_resnet, timm_xception, timm_VGG16 = (
    timm.create_model(arch, pretrained=True, num_classes=3, drop_rate=0.5)
    for arch in ('resnet50', 'inception_v3', 'inception_resnet_v2', 'xception', 'vgg16')
)

# Ensemble members in the order they will be trained.
timm_models = [
    timm_resnet,
    timm_inception_v3,
    timm_inception_resnet,
    timm_xception,
    timm_VGG16,
]

In [None]:
# Hyperparameters for the full ensemble run.
epochs = 15
lr = 1e-4
loss_fn = torch.nn.CrossEntropyLoss()
log = True

# Pass the loss defined above instead of constructing a second, identical one,
# so that changing `loss_fn` in this cell actually affects training.
train_all_models(models=timm_models,
                 train_dataloader=train_dataloader,
                 val_dataloader=val_dataloader,
                 loss_fn=loss_fn,
                 lr=lr,
                 epochs=epochs,
                 device=DEVICE,
                 log_experiment=log)

In [27]:
# Short two-epoch run of the ResNet model alone, with experiment logging off.
epochs, lr, log = 2, 1e-4, False
loss_fn = torch.nn.CrossEntropyLoss()
model = timm_resnet.to(DEVICE)

train(
    model,
    train_dataloader,
    val_dataloader,
    loss_fn,
    torch.optim.Adam(model.parameters(), lr=lr),
    epochs=epochs,
    device=DEVICE,
    log_experiment=log,
)

  0%|          | 0/2 [00:00<?, ?it/s]



Epoch 1
-------------------------------


  0%|          | 0/2 [00:15<?, ?it/s]


KeyboardInterrupt: 