In [None]:
from src.models.our_model import OurModel
from src.models.pretrained_models import VGG16Pretrained, ResNetPretrained
import torch
from src.models.ensemble import HardVotingEnsemble, SoftVotingEnsemble, MetaClassifier, StackingEnsemble
from src.utils import load_data, evaluate_model
from src.model_trainer import ModelTrainer
from src.transformations import normalized_simple_transform
# Seed PyTorch's global RNG so runs are repeatable.
torch.manual_seed(123)
# NOTE(review): 14 CPU threads is hard-coded, presumably tuned for the
# author's machine — adjust for other hosts.
torch.set_num_threads(14)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device_str = "cuda"
else:
    device_str = "cpu"
device = torch.device(device_str)
print(device)

In [None]:
  
# Directory holding the trained OurModel checkpoints. Forward slashes work on
# both Windows and POSIX. The previous backslash literals relied on invalid
# escape sequences ("\s", "\o", "\c", "\O"), which raise a SyntaxWarning on
# Python 3.12+. They also broke on non-Windows systems.
CHECKPOINT_DIR = "./saved_models/ourmodel/combined_20"


def _load_our_model(checkpoint_name):
    """Build an OurModel, load weights from CHECKPOINT_DIR, and move it to `device`.

    Args:
        checkpoint_name: File name of the state_dict inside CHECKPOINT_DIR.

    Returns:
        The OurModel instance with the checkpoint's weights, on `device`.
    """
    model = OurModel()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load(f"{CHECKPOINT_DIR}/{checkpoint_name}", map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    return model


model1 = _load_our_model("OurModel_16.pth")
# NOTE(review): model2 loads the same checkpoint as model1, so the ensemble
# contains two identical base learners — confirm this is intended, or point it
# at a different epoch's checkpoint.
model2 = _load_our_model("OurModel_16.pth")
model3 = _load_our_model("OurModel_20.pth")

# model2 = VGG16Pretrained()
# model2.load_state_dict(torch.load(model2_path, map_location=device))
# model2.to(device)

# model3 = ResNetPretrained()
# model3.load_state_dict(torch.load(model3_path, map_location=device))
# model3.to(device)

base_models = [model1, model2, model3]

# Number of classes predicted by each base model and by the meta-classifier.
# The value 10 is kept from the original code — TODO confirm against OurModel's
# output layer.
NUM_CLASSES = 10

# The meta-classifier's input width is one NUM_CLASSES-sized output per base
# model (presumably concatenated by StackingEnsemble — verify there).
meta_input_dim = len(base_models) * NUM_CLASSES
meta_model = MetaClassifier(meta_input_dim, NUM_CLASSES)
meta_model.to(device)

# use_probs=True: presumably the base models' class probabilities (rather than
# raw logits) are fed to the meta-classifier — confirm in StackingEnsemble.
stacking_ensemble = StackingEnsemble(base_models, meta_model, use_probs=True)
stacking_ensemble.to(device)

# Only the training loader is shuffled. Evaluation loaders keep a fixed order:
# shuffling does not improve a full-pass metric. It also consumes RNG state and
# makes per-sample predictions non-reproducible across runs.
test_loader = load_data('./data/test', batch_size=128, shuffle=False, transform=normalized_simple_transform(), num_workers=1)
train_loader = load_data('./data/train', batch_size=128, shuffle=True, transform=normalized_simple_transform(), num_workers=1)
valid_loader = load_data('./data/valid', batch_size=512, shuffle=False, transform=normalized_simple_transform(), num_workers=1)

# Train the stacking ensemble with Adam and no weight decay.
# NOTE(review): unless StackingEnsemble freezes its base models internally, the
# optimizer ModelTrainer builds over stacking_ensemble may also update the
# pre-trained base models' weights. Stacking usually trains only the
# meta-classifier, so confirm this is intended.
trainer = ModelTrainer(
    model=stacking_ensemble,
    train_loader=train_loader,
    valid_loader=valid_loader,
    device=device,
    optimizer_type="adam",
    learning_rate=0.001,
    weight_decay=0,
    # Was ".saved_models/ensemble", which wrote into a hidden ".saved_models"
    # directory instead of the project's ./saved_models tree.
    save_dir="./saved_models/ensemble",
    log_file="ensemble.json"
)

trainer.train(epochs=10)