In [None]:
# Install AK_SSL

!pip install AK_SSL

In [None]:
# import libraries

from AK_SSL import Trainer
import torch
import torchvision

In [None]:
# load pretext dataset
# STL-10 'unlabeled' split stored under ../datasets/stl10 (download=True lets
# torchvision fetch it when needed), converted to tensors. This is the data the
# self-supervised method is trained on in the Trainer cell.
train_unlabeled_dataset = torchvision.datasets.STL10(
    root="../datasets/stl10",
    split="unlabeled",
    download=True,
    transform=torchvision.transforms.ToTensor(),
)


In [None]:
# define backbone and remove the last layer
# Seed torch's global RNG first so the random backbone initialisation (and any
# later draws from the same RNG — presumably the Trainer's shuffling and
# augmentations, confirm) repeat across Restart & Run All. Some GPU (cuDNN)
# kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# ResNet-18 with no pretrained weights. The final fc layer is swapped for an
# identity, so the model returns the features that used to feed fc. Their width
# (fc.in_features) is kept as feature_size and passed to the Trainer.
backbone = torchvision.models.resnet18(weights=None)
feature_size = backbone.fc.in_features
backbone.fc = torch.nn.Identity()

In [None]:
# define Trainer
# Barlow Twins trainer around the headless backbone, fed with the unlabeled
# pretext dataset at image_size=96.
trainer = Trainer(
    # SSL method and the network it trains
    method="barlowtwins",
    backbone=backbone,
    feature_size=feature_size,
    # pretext data
    dataset=train_unlabeled_dataset,
    image_size=96,
    # checkpointing: start fresh (no reload); checkpoint_interval is presumably
    # counted in epochs — confirm against the AK_SSL docs
    reload_checkpoint=False,
    checkpoint_interval=50,
    save_dir="./save_for_report/",
)

In [None]:
# train
# Pretext training: Adam (lr 1e-3, weight decay 1e-6), batch size 256, with the
# epoch counter starting at 1 and epochs=500. This is a fresh run, matching
# reload_checkpoint=False on the Trainer.
trainer.train(
    optimizer="Adam",
    learning_rate=1e-3,
    weight_decay=1e-6,
    batch_size=256,
    start_epoch=1,
    epochs=500,
)

In [None]:
# load evaluation datasets
# The labelled STL-10 'train' and 'test' splits, loaded from the same root and
# with the same ToTensor transform as the pretext data, so images arrive in the
# same format.


def load_stl10_split(split, root="../datasets/stl10"):
    """Return the STL-10 dataset for `split`, stored under `root`, as tensors.

    download=True lets torchvision fetch the files when they are not present.
    """
    return torchvision.datasets.STL10(
        root=root,
        split=split,
        transform=torchvision.transforms.ToTensor(),
        download=True,
    )


train_label_dataset = load_stl10_split("train")
test_dataset = load_stl10_split("test")

In [None]:
# evaluate
# "linear" evaluation on the labelled splits. It presumably trains a linear head
# on top of the SSL backbone (confirm in AK_SSL). fine_tuning_data_proportion=1
# uses the whole train_label_dataset; results are scored with top_k=1 on
# test_dataset.
trainer.evaluate(
    eval_method="linear",
    train_dataset=train_label_dataset,
    test_dataset=test_dataset,
    fine_tuning_data_proportion=1,
    top_k=1,
    # optimisation settings for the evaluation run
    optimizer='Adam',
    learning_rate=1e-3,
    weight_decay=1e-6,
    batch_size=256,
    epochs=100,
)