In [None]:
from vision_transformers.models import vit

def build_model(num_classes=10):
    """Construct the ViT model used by this notebook.

    Delegates to ``vit.vit_b_p16_224`` with a fixed ``image_size`` of 224
    and ``pretrained=True``.

    Args:
        num_classes: Forwarded unchanged as the factory's ``num_classes``
            argument. Defaults to 10.

    Returns:
        The object returned by ``vit.vit_b_p16_224``.
    """
    # NOTE(review): how `pretrained=True` interacts with a custom
    # `num_classes` is decided inside the factory -- confirm there.
    factory_kwargs = dict(
        image_size=224,
        num_classes=num_classes,
        pretrained=True,
    )
    return vit.vit_b_p16_224(**factory_kwargs)

# Instantiate the model with the default class count and print its structure.
model = build_model()
print(model)

# Take one snapshot of the parameters so both counts below are derived from
# the same list instead of iterating the model twice.
_model_params = list(model.parameters())

total_params = sum(p.numel() for p in _model_params)
print(f'{total_params:,} total parameters.')

# Only parameters flagged with `requires_grad` count as trainable.
total_trainable_params = sum(
    p.numel() for p in _model_params if p.requires_grad
)
print(f'{total_trainable_params:,} total trainable parameters.')


In [5]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Dataset locations, relative to the working directory of the notebook.
# NOTE(review): 'trainning' is kept exactly as spelled in the original path;
# presumably it matches the folder name on disk -- verify before "fixing" it.
_DATASET_ROOT = './data/Cotton-Disease-Training'
TRAIN_DIR = f'{_DATASET_ROOT}/trainning/Cotton leaves - Training/800 Images'
VALID_DIR = f'{_DATASET_ROOT}/validation/Cotton leaves - Validation/100 Images'

# Side length handed to the train/validation transform builders.
IMAGE_SIZE = 224
# Value used for the `num_workers` argument of the DataLoaders.
NUM_WORKERS = 4

In [6]:
def get_train_transform(image_size):
    """Build the augmentation and preprocessing pipeline for training images.

    The steps are applied in this order: resize to a square of
    ``image_size``, random horizontal flip (p=0.5), ``RandomRotation(35)``,
    random sharpness adjustment (factor 2, p=0.5), conversion to a tensor,
    and per-channel normalization.

    Args:
        image_size: Side length used for both dimensions of the resize.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    target_size = (image_size, image_size)
    # NOTE(review): these look like the usual ImageNet channel statistics,
    # presumably chosen to match the pretrained weights -- confirm.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(target_size),
        # Random augmentations are applied on the image before tensor conversion.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(35),
        transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

def get_valid_transform(image_size):
    """Build the preprocessing pipeline for validation images.

    Contains no random augmentation: the image is resized to a square of
    ``image_size``, converted to a tensor, and normalized per channel.

    Args:
        image_size: Side length used for both dimensions of the resize.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    target_size = (image_size, image_size)
    # NOTE(review): these look like the usual ImageNet channel statistics,
    # presumably chosen to match the pretrained weights -- confirm.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)

In [None]:
def get_datasets():
    """Create the training and validation ``ImageFolder`` datasets.

    Reads the module-level ``TRAIN_DIR``, ``VALID_DIR`` and ``IMAGE_SIZE``.
    The training dataset gets the transform from ``get_train_transform``
    and the validation dataset the one from ``get_valid_transform``.

    Returns:
        A ``(train_dataset, valid_dataset)`` tuple.
    """
    train_dataset = datasets.ImageFolder(
        root=TRAIN_DIR,
        transform=get_train_transform(IMAGE_SIZE),
    )
    valid_dataset = datasets.ImageFolder(
        root=VALID_DIR,
        transform=get_valid_transform(IMAGE_SIZE),
    )
    return train_dataset, valid_dataset

def get_data_loader(dataset_train, dataset_valid, batch_size, num_workers=None):
    """Wrap the training and validation datasets in ``DataLoader`` objects.

    The training loader shuffles its data; the validation loader does not.
    Both loaders share the same batch size and worker count.

    Args:
        dataset_train: Dataset to serve through the training loader.
        dataset_valid: Dataset to serve through the validation loader.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Value passed as ``num_workers`` to both loaders. When
            left as ``None`` the module-level ``NUM_WORKERS`` is used, which
            matches the previous hard-coded behavior. Pass ``0`` to load in
            the main process.

    Returns:
        A ``(train_loader, valid_loader)`` tuple.
    """
    # Resolve the default lazily so the global is only read when the caller
    # did not supply an explicit worker count.
    if num_workers is None:
        num_workers = NUM_WORKERS

    shared_kwargs = dict(batch_size=batch_size, num_workers=num_workers)

    train_loader = DataLoader(dataset_train, shuffle=True, **shared_kwargs)
    valid_loader = DataLoader(dataset_valid, shuffle=False, **shared_kwargs)

    return train_loader, valid_loader