# Fine-tuning Basenji2 with PyTorch
**Authorship:**
Adam Klie (last updated: *07/19/2023*)
***
**Description:**
This notebook demonstrates how to fine-tune a pre-trained Basenji2 model with PyTorch. The PyTorch implementation is detailed in the setup.ipynb notebook.
***

# Set-up

In [None]:
# General imports
import os
import json
import torch

# In case the PYTHONPATH is not set
import sys
sys.path.append('/cellar/users/aklie/opt/ml4gland/basenji2-pytorch')

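# Free cached CUDA memory: collect unreachable Python objects first so that
# their tensors are released before the allocator cache is emptied
import gc
gc.collect()
torch.cuda.empty_cache()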

In [None]:
# Import the code for loading the PyTorch model
from basenji2_pytorch import Basenji2, params # or PLBasenji2 to use training parameters from Kelley et al. 2020

# Define the path to your downloaded weights
model_weights = '/cellar/users/aklie/projects/ML4GLand/models/Basenji/basenji2.pth'  # TODO: Change this to your path

In [None]:
# Open up the model config file
with open(params) as params_open:
    model_params = json.load(params_open)['model']

In [None]:
# Remove the human output head to get a headless model, e.g. for transfer learning
model_params.pop("head_human", None)
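The `pop` call with a `None` default removes the head entry if it is present and does nothing otherwise, so the cell is safe to re-run. Below is a minimal sketch of that behavior. The config text is hypothetical: only the `"model"` and `"head_human"` keys come from this notebook, and the `"trunk"` entry and all values are placeholders.

```python
import json

# Hypothetical stand-in for the params file; the real file ships with
# basenji2_pytorch and its full contents are not shown in this notebook.
config_text = '{"model": {"trunk": ["conv_block"], "head_human": {"units": 8}}}'
model_params = json.loads(config_text)["model"]

# First call removes the head entry and returns it.
removed = model_params.pop("head_human", None)
print(sorted(model_params))

# Second call finds nothing to remove and returns the default instead of
# raising KeyError, so re-running the cell is harmless.
removed_again = model_params.pop("head_human", None)
print(removed_again)
```

Keeping the popped value in a variable such as `removed` can be useful if the head configuration is needed later as a template for a new task-specific head.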

# Load in the PyTorch model

In [None]:
basenji2 = Basenji2(model_params)

In [None]:
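# strict=False skips checkpoint keys that have no match in the model (e.g. the removed head);
# map_location="cpu" lets the checkpoint load on machines without a GPU
basenji2.load_state_dict(torch.load(model_weights, map_location="cpu"), strict=False)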
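Because `strict=False` silently ignores mismatched keys, it is worth checking what was actually loaded. `load_state_dict` returns an object with `missing_keys` and `unexpected_keys` fields. The sketch below illustrates this on a toy model: `ToyModel` and its layer names are hypothetical and do not reflect the Basenji2 architecture. In the headless setting one would presumably expect only head-related keys to appear as unexpected and no trunk keys to be missing.

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    """Hypothetical two-part model standing in for a trunk plus an output head."""

    def __init__(self, with_head: bool):
        super().__init__()
        self.trunk = nn.Linear(4, 8)
        if with_head:
            self.head = nn.Linear(8, 2)


# A "pre-trained" model with a head, and a headless model to load into.
full = ToyModel(with_head=True)
headless = ToyModel(with_head=False)

# strict=False tolerates the head keys that the headless model lacks.
result = headless.load_state_dict(full.state_dict(), strict=False)

# Keys the model expected but the checkpoint did not provide.
print("missing:", result.missing_keys)
# Keys the checkpoint provided but the model has no parameter for.
print("unexpected:", result.unexpected_keys)

# The shared trunk weights were copied over.
print(torch.equal(headless.trunk.weight, full.trunk.weight))
```

Applying the same check to the return value of the `basenji2.load_state_dict(...)` call above would flag a misnamed or partially loaded trunk, which `strict=False` would otherwise hide.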