In [1]:
from few_shot import *
from torchvision import datasets
from src.transformations import *
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.models import vgg16, VGG16_Weights
from easyfsl.utils import plot_images

# Imported after the wildcard imports so that `random` and `torch` are
# guaranteed to name the real modules in the notebook namespace.
import random
import torch

In [2]:
# One ImageFolder per split directory, each built with its own
# resize_transform() instance; unpacked in train / test / valid order.
train_set, test_set, validation_test = (
    datasets.ImageFolder(root=split_dir, transform=resize_transform())
    for split_dir in ("./data/train", "./data/test", "./data/valid")
)

In [3]:
# Backbone: torchvision ResNet-18 initialised with its default pretrained weights.
resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
# Swap the `fc` attribute for a Flatten module before wrapping; presumably so the
# backbone emits feature vectors rather than class scores (confirm against what
# PrototypicalNetwork expects from its backbone).
resnet.fc = nn.Flatten()
# `nn` and `PrototypicalNetwork` are not imported by name in this notebook; they
# presumably arrive through one of the wildcard imports in the first cell.
model = PrototypicalNetwork(resnet)

In [4]:
# Episode configuration shared by every few-shot dataloader in this notebook.
N_WAY = 5  # Number of classes in a task
N_SHOT = 5  # Number of images per class in the support set
N_QUERY = 10  # Number of images per class in the query set
N_EVALUATION_TASKS = 10  # Task count handed to the test loader
N_TRAINING_EPISODES = 5  # Task count handed to the train loader
N_VALIDATION_TASKS = 5  # Task count handed to the validation loader

# Seed the Python and PyTorch global RNGs before any loader is built or any
# training runs, so that re-running the notebook from a fresh kernel starts from
# the same random state. NOTE(review): whether this makes the sampled tasks fully
# reproducible depends on which RNG get_few_shot_dataloader draws from - confirm.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [5]:
# Build the three episodic loaders with identical N_WAY / N_SHOT / N_QUERY
# settings; only the dataset and the number of tasks differ per split.
# The calls run in the same order as before: test, train, validation.
test_loader, train_loader, validation_loader = (
    get_few_shot_dataloader(split_dataset, N_WAY, N_SHOT, N_QUERY, task_count)
    for split_dataset, task_count in (
        (test_set, N_EVALUATION_TASKS),
        (train_set, N_TRAINING_EPISODES),
        (validation_test, N_VALIDATION_TASKS),
    )
)

In [8]:
# Bundle the ResNet-backed prototypical model with the train / test / validation
# loaders; every later evaluate() / train() call in this section goes through
# this trainer object.
fewshot_trainer = FewShotTrainer(
    model=model,
    train_loader=train_loader,
    test_loader=test_loader,
    val_loader=validation_loader,
)

## Test Evaluation

In [9]:
# Pull one batch from the trainer's test loader and unpack its five components
# so a single task can be inspected by hand.
(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
    example_class_ids,  # not used by any cell visible in this notebook
) = next(iter(fewshot_trainer.test_loader))

In [10]:
# Score the model on the single example task unpacked above. The recorded output
# is the pair (28, 50); presumably (correct predictions, total queries), since
# N_WAY * N_QUERY = 50 (confirm against evaluate_on_one_task).
fewshot_trainer.evaluate_on_one_task(
    example_support_images,
    example_support_labels,
    example_query_images,
    example_query_labels,
)

(28, 50)

In [11]:
# Baseline: evaluate before train() has been called on this trainer. The recorded
# run reports 10 tasks, matching N_EVALUATION_TASKS, so this presumably iterates
# the test loader.
fewshot_trainer.evaluate()

100%|██████████| 10/10 [00:20<00:00,  2.01s/it]

Model tested on 10 tasks. Accuracy: 49.80%





## Test Training

In [12]:
# Run training, then evaluate again so the result can be compared with the
# pre-training baseline. The recorded output shows train() also printing a
# validation accuracy before evaluate() reports on the 10 test tasks.
fewshot_trainer.train()
fewshot_trainer.evaluate()

100%|██████████| 5/5 [00:23<00:00,  4.79s/it, loss=0.984]
100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


Validation Accuracy: 47.60%


100%|██████████| 10/10 [00:18<00:00,  1.89s/it]

Model tested on 10 tasks. Accuracy: 50.40%





## Other models

### VGG16 pretrained

In [None]:
# Second backbone: torchvision VGG16 with its default pretrained weights.
# (vgg16 / VGG16_Weights are imported in the notebook's first cell.)
vgg = vgg16(weights=VGG16_Weights.DEFAULT)
# Replace the `classifier` attribute with a Flatten module before wrapping,
# mirroring what is done to the ResNet backbone's `fc` attribute.
vgg.classifier = nn.Flatten()
model2 = PrototypicalNetwork(vgg)

In [14]:
# Same experiment as the ResNet section, but with the pretrained-VGG16 model:
# the three loaders are reused unchanged, so only the backbone differs.
fewshot_trainer2 = FewShotTrainer(
    model=model2,
    train_loader=train_loader,
    test_loader=test_loader,
    val_loader=validation_loader,
)

# Train first, then report accuracy; unlike the ResNet section, there is no
# pre-training evaluate() call here to serve as a baseline.
fewshot_trainer2.train()
fewshot_trainer2.evaluate()

100%|██████████| 5/5 [02:34<00:00, 30.98s/it, loss=1.53]
100%|██████████| 5/5 [01:01<00:00, 12.22s/it]


Validation Accuracy: 38.80%


100%|██████████| 10/10 [02:04<00:00, 12.45s/it]

Model tested on 10 tasks. Accuracy: 37.80%





### VGG16 custom

In [15]:
from src.models.vgg_custom import *

In [16]:
# Third backbone: the project's VGG16Custom, constructed with no arguments
# (no pretrained weights are passed here).
vgg_custom = VGG16Custom()
# Replace the `classifier` attribute with a Flatten module before wrapping,
# as is done for the torchvision VGG16 backbone.
vgg_custom.classifier = nn.Flatten()
model3 = PrototypicalNetwork(vgg_custom)

In [17]:
# Same experiment again with the custom-VGG16 model and the same three loaders.
fewshot_trainer3 = FewShotTrainer(
    model=model3,
    train_loader=train_loader,
    test_loader=test_loader,
    val_loader=validation_loader,
)

# Train, then evaluate. The recorded accuracies (20.80% validation, 22.20% test)
# are close to the 20% chance level of a 5-way task (N_WAY = 5).
fewshot_trainer3.train()
fewshot_trainer3.evaluate()

100%|██████████| 5/5 [02:19<00:00, 27.88s/it, loss=1.61]
100%|██████████| 5/5 [00:52<00:00, 10.55s/it]


Validation Accuracy: 20.80%


100%|██████████| 10/10 [01:51<00:00, 11.10s/it]

Model tested on 10 tasks. Accuracy: 22.20%



