In [None]:

%load_ext autoreload
%autoreload 2

import copy
import pytorch_lightning as pl
import torchinfo

from PlantINaturalist2021DataModule import PlantINaturalist2021DataModule as imported_datamodule
from PlantINaturalist2021FineTuneDensenet201 import PlantINaturalist2021FineTuneDensenet201 as imported_model
from image_transformers import transform_autoaugment as imported_transform

config = {
    "model_name": imported_model.__name__,
    "num_classes": 250,
    "learning_rate": 0.01,
    "lr_decay_epoch_step_size": 5,
    "lr_decay_rate": 0.9,
    "num_trainable_layers": 3,
    "transform": imported_transform.__name__,
    "batch_size": 64
}


In [None]:

# Sanity run: push a few batches through training and validation to catch wiring
# errors before the long run. fast_dev_run also disables loggers and checkpointing,
# so this smoke test leaves nothing behind.
# NOTE(review): uses context="retrain" while the main run below uses "train" — confirm intended.
pl.Trainer(fast_dev_run=5).fit(
    model=imported_model(config),
    datamodule=imported_datamodule(transform=imported_transform.get(), context="retrain"),
)


In [None]:
# Build the model and the training datamodule, then show a layer summary.
model = imported_model(config)

# batch_size is read from config so the value logged to W&B matches what is actually used.
datamodule = imported_datamodule(
    transform=imported_transform.get(),
    context="train",
    batch_size=config["batch_size"],
    num_workers=2,
    pin_memory=True,
    data_dir="./",
)

torchinfo.summary(model)

In [None]:
# Short (3-minute) trial fit used to tune the datamodule settings (workers, pinning, ...).
import copy

datamodule = imported_datamodule(
    transform=imported_transform.get(),
    context="train",
    batch_size=config["batch_size"],
    num_workers=2,
    pin_memory=True,
    data_dir="./",
)

# Train a throwaway copy: fitting `model` itself here would mean the W&B-logged
# training run further down starts from an already partially trained model.
trial_model = copy.deepcopy(model)

trainer = pl.Trainer(benchmark=True, max_time="00:00:03:00", accelerator='gpu', devices=1)

trainer.fit(model=trial_model, datamodule=datamodule)

In [None]:
# Detailed model summary (verbose=2 also lists individual weight/bias tensors).
# verbose=2 prints the table itself; the trailing ';' stops Jupyter from rendering
# the returned statistics object a second time as the cell output.
torchinfo.summary(model, verbose=2);

In [None]:
# Open the W&B run for the main training job and snapshot the source files that
# define the model, the datamodule and the image transforms alongside it.
# torch and WandbLogger are not used here but by the cells that follow.
import wandb
import torch
from pytorch_lightning.loggers import WandbLogger


run = wandb.init(config=config, project='PlantINaturalist2')

model_source_file = f"{imported_model.__name__}.py"
datamodule_source_file = f"{imported_datamodule.__name__}.py"
wandb.save(model_source_file)
wandb.save(datamodule_source_file)
wandb.save("image_transformers.py")

In [None]:
# Main training run, capped at 2 hours of wall-clock time. WandbLogger attaches to
# the W&B run opened with wandb.init above.
wandb_logger = WandbLogger()
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    benchmark=True,
    logger=wandb_logger,
    max_time="00:02:00:00",
)
trainer.fit(model=model, datamodule=datamodule)

In [None]:
# Persist the trained weights, attach them to the run as a 'model' artifact,
# then close the run.
weights_file = 'model.pth'
torch.save(model.state_dict(), weights_file)

artifact = wandb.Artifact('model', type='model')
artifact.add_file(weights_file)
run.log_artifact(artifact)

run.finish()

In [None]:
# Peek at the first parameter tensor of the classifier head after training.
classifier_parameters = model.model.classifier.parameters()
next(classifier_parameters)

Fine-tune and retrain

In [None]:
# Download the previously trained checkpoint. The training run was finished above,
# so a fresh run is opened to consume the artifact (this also records lineage in W&B).
import wandb

wandb.finish()  # close any run that is still active before starting a new one
run = wandb.init(project='PlantINaturalist2', job_type='finetune')
artifact = run.use_artifact('pasoi0stefan/PlantINaturalist/model:v11', type='model')
artifact_dir = artifact.download()

In [None]:
from PlantINaturalist2021DataModule import PlantINaturalist2021DataModule as imported_datamodule
from PlantINaturalist2021FinetuneMobileNetv2 import PlantINaturalist2021FinetuneMobileNetv2 as imported_model
from image_transformers import transform_autoaugment as imported_transform

# Hyperparameters for the MobileNetV2 fine-tune / retrain experiments below.
config = dict(
    model_name=imported_model.__name__,
    num_classes=250,
    learning_rate=0.01,
    lr_decay_epoch_step_size=5,
    lr_decay_rate=0.9,
    num_trainable_layers=2,
)

In [None]:
# Load the downloaded checkpoint into a freshly built model for inspection.
import torch

model_artifact = imported_model(config)
# map_location="cpu": nothing here moves the new model off the CPU, so load the tensors
# there too; this also works on a machine without the GPU the checkpoint was saved from.
model_artifact.load_state_dict(torch.load(f"{artifact_dir}/model.pth", map_location="cpu"))

In [None]:
# FINETUNE: start from the downloaded checkpoint and keep training on
# context="finetune" with a 10x lower learning rate (0.001 instead of config's 0.01).
import copy

import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project='PlantINaturalist2')

# The lower learning rate goes in through the constructor config. The original manual
# configure_optimizers() call had no lasting effect: its return value was discarded, and
# Lightning calls configure_optimizers() itself inside fit().
finetune_config = copy.deepcopy(config)
finetune_config["learning_rate"] = 0.001
model_finetune = imported_model(finetune_config)
print(next(model_finetune.model.classifier.parameters())[:2])
model_finetune.load_state_dict(torch.load(f"{artifact_dir}/model.pth", map_location="cpu"))
print(next(model_finetune.model.classifier.parameters())[:2])
# NOTE(review): kept in case the model's configure_optimizers reads self.learning_rate.
model_finetune.learning_rate = 0.001

datamodule_finetune = imported_datamodule(transform=imported_transform.get(), context="finetune", batch_size=32, num_workers=1, pin_memory=True, data_dir="./")
trainer = pl.Trainer(benchmark=True, logger=wandb_logger, max_epochs=20, accelerator='gpu', devices=1)
trainer.fit(model=model_finetune, datamodule=datamodule_finetune)

torch.save(model_finetune.state_dict(), 'finetuned_model.pth')
artifact = wandb.Artifact('finetuned_model', type='model')
artifact.add_file('finetuned_model.pth')
# Log to the run the logger actually wrote to. A previously bound `run` may already be
# finished, and log_artifact on a finished run is ignored or raises, depending on the
# wandb version.
run = wandb_logger.experiment
run.log_artifact(artifact)
run.finish()

In [None]:
# FINETUNE 2: start from the downloaded checkpoint and train on context="retrain"
# with the learning rate from config.
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb_logger = WandbLogger(project='PlantINaturalist2')
model_finetune = imported_model(config)
print(next(model_finetune.model.classifier.parameters())[:2])
model_finetune.load_state_dict(torch.load(f"{artifact_dir}/model.pth", map_location="cpu"))
print(next(model_finetune.model.classifier.parameters())[:2])
# A manual configure_optimizers() call is not needed: Lightning calls it during fit().

datamodule_finetune = imported_datamodule(transform=imported_transform.get(), context="retrain", batch_size=32, num_workers=1, pin_memory=True, data_dir="./")
trainer = pl.Trainer(benchmark=True, logger=wandb_logger, max_epochs=10, accelerator='gpu', devices=1)
trainer.fit(model=model_finetune, datamodule=datamodule_finetune)

torch.save(model_finetune.state_dict(), 'finetuned_model2.pth')
artifact = wandb.Artifact('finetuned_model2', type='model')
artifact.add_file('finetuned_model2.pth')
# The previous cell finished its run, so log to the run this cell's logger created.
run = wandb_logger.experiment
run.log_artifact(artifact)
run.finish()

In [None]:
# RETRAIN: train a fresh model (without loading the downloaded checkpoint) on
# context="retrain" for comparison with the fine-tuned variants.
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger

wandb.finish()  # close any run left open by the cells above
# Same project as the main training run, and record this experiment's hyperparameters.
run = wandb.init(project='PlantINaturalist2', config=config, job_type='retrain')
wandb_logger = WandbLogger()
model_retrain = imported_model(config)
datamodule_retrain = imported_datamodule(transform=imported_transform.get(), context="retrain", batch_size=32, num_workers=1, pin_memory=True, data_dir="./")
trainer = pl.Trainer(benchmark=True, logger=wandb_logger, max_epochs=40, accelerator='gpu', devices=1)
trainer.fit(model=model_retrain, datamodule=datamodule_retrain)

torch.save(model_retrain.state_dict(), 'retrained_model.pth')
artifact = wandb.Artifact('retrained_model', type='model')
artifact.add_file('retrained_model.pth')
run.log_artifact(artifact)
run.finish()

In [None]:
# Peek at the first classifier parameter tensor of the checkpoint-loaded model.
artifact_classifier_parameters = model_artifact.model.classifier.parameters()
next(artifact_classifier_parameters)