In [None]:
import os
import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights
from urllib.request import urlretrieve

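Because the full fine-tuned model is too large to host on GitHub, I have provided the `resnet50_finetune_diff.pth` file instead. It stores only the state-dict tensors that differ from the ImageNet-pretrained ResNet50. That is mainly the fine-tuned last stages (`layer4` and `fc`), plus any other tensors that changed during fine-tuning. The run below saved 191 tensors in total, presumably including BatchNorm running statistics outside `layer4`. The code that produced the file is shown below; you don't need to run these two cells.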

In [2]:
def save_finetune_diff(base_model, finetuned_model, output_path="resnet50_finetune_diff.pth", rtol=1e-6, atol=1e-8):
    """
    Compare base model and fine-tuned model weights, save only changed weights.

    Every entry of ``finetuned_model.state_dict()`` (parameters and buffers) is
    saved when any of the following holds:

    * the base model has no entry with that name;
    * the two tensors differ in shape or dtype;
    * their values are not ``torch.allclose`` within ``rtol``/``atol``.

    Args:
        base_model (torch.nn.Module): Original pretrained model.
        finetuned_model (torch.nn.Module): Fine-tuned model.
        output_path (str): Where to save the diff file.
        rtol, atol (float): Tolerances for considering parameters 'equal'.

    Returns:
        dict[str, torch.Tensor]: The saved diff (CPU tensors), keyed by state-dict name.
    """
    base_state = base_model.state_dict()
    fine_state = finetuned_model.state_dict()

    diff_state = {}
    for name, fine_tensor in fine_state.items():
        base_tensor = base_state.get(name)
        changed = (
            base_tensor is None
            # torch.allclose raises on shape/dtype mismatch (e.g. a replaced
            # classifier head), so treat any such mismatch as a change.
            or base_tensor.shape != fine_tensor.shape
            or base_tensor.dtype != fine_tensor.dtype
            # Compare on a single device; the two models may live on different ones.
            or not torch.allclose(fine_tensor, base_tensor.to(fine_tensor.device), rtol=rtol, atol=atol)
        )
        if changed:
            diff_state[name] = fine_tensor.cpu()

    torch.save(diff_state, output_path)
    print(f"[INFO] Saved fine-tune diff with {len(diff_state)} tensors to {output_path}")
    return diff_state


In [4]:
NUM_CLASSES = 17  # must match the classifier head used during fine-tuning

# Reference model: ImageNet-pretrained backbone with a freshly initialised head,
# mirroring how load_finetuned_resnet50 builds its base before applying the diff.
# The random head always differs from the trained one, so fc.* ends up in the diff.
base_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
base_model.fc = nn.Linear(base_model.fc.in_features, NUM_CLASSES)

finetuned_model = resnet50(weights=None)
finetuned_model.fc = nn.Linear(finetuned_model.fc.in_features, NUM_CLASSES)
# The checkpoint is a plain state dict, so weights_only=True can load it without
# unpickling arbitrary objects; map_location keeps everything on the CPU.
finetuned_model.load_state_dict(
    torch.load("resnet50_custom_17class.pth", map_location="cpu", weights_only=True)
)

save_finetune_diff(base_model, finetuned_model, output_path="resnet50_finetune_diff.pth")

  finetuned_model.load_state_dict(torch.load("resnet50_custom_17class.pth"))


[INFO] Saved fine-tune diff with 191 tensors to resnet50_finetune_diff.pth


To load the fine-tuned model's weights, simply run the cell below:

In [None]:
def load_finetuned_resnet50(
    num_classes=17,
    diff_url="https://raw.githubusercontent.com/AmirHossienAfshar/cv-noise-denoise/master/saved_models/resnet50_finetune_diff.pth",
    local_diff_path="resnet50_finetune_diff.pth",
    base_model_path=None,
    device=None,
    eval_mode=True,
):
    """
    Loads ResNet50, applies fine-tuned diff weights.

    Args:
        num_classes (int): Number of output classes; must match the head stored in the diff.
        diff_url (str): URL to download fine-tuned diff weights if not present locally.
        local_diff_path (str): Path to store/load the diff weights.
        base_model_path (str or None): Optional local path to a ResNet50 state dict.
            This may be the ImageNet base or a full fine-tuned checkpoint; a classifier
            head of any size is accepted. If None or the file does not exist,
            TorchVision's IMAGENET1K_V1 weights are used.
        device (str or None): Device to load model onto; defaults to "cuda" when
            available, otherwise "cpu".
        eval_mode (bool): If True (default), return the model in ``eval()`` mode,
            ready for inference.

    Returns:
        torch.nn.Module: The ResNet50 with the diff weights applied, on ``device``.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    if base_model_path and os.path.exists(base_model_path):
        print(f"[INFO] Loading base model from {base_model_path}")
        model = resnet50(weights=None)  # initialize empty
        # Load on the CPU (the model is moved to `device` once, at the end) and only
        # accept plain tensors/containers rather than arbitrary pickled objects.
        checkpoint = torch.load(base_model_path, map_location="cpu", weights_only=True)
        # A fine-tuned checkpoint carries a smaller head than the default 1000-class
        # fc; resize the head first so load_state_dict does not fail on fc.* shapes.
        fc_weight = checkpoint.get("fc.weight")
        if fc_weight is not None and fc_weight.shape[0] != model.fc.out_features:
            model.fc = nn.Linear(model.fc.in_features, fc_weight.shape[0])
        model.load_state_dict(checkpoint)
    else:
        if base_model_path:
            print(f"[WARN] Base model path {base_model_path} not found; falling back to TorchVision weights")
        print("[INFO] Loading base model from TorchVision (ImageNet1K weights)")
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

    # Replace the head only when its size differs; its weights are then expected
    # to come from the diff applied below.
    if model.fc.out_features != num_classes:
        model.fc = nn.Linear(model.fc.in_features, num_classes)

    if not os.path.exists(local_diff_path):
        print(f"[INFO] Diff file not found. Downloading from {diff_url} ...")
        diff_dir = os.path.dirname(local_diff_path)
        if diff_dir:
            os.makedirs(diff_dir, exist_ok=True)
        # Download under a temporary name and rename on success, so an interrupted
        # download never leaves a truncated file that later runs would trust.
        partial_path = local_diff_path + ".part"
        try:
            urlretrieve(diff_url, partial_path)
            os.replace(partial_path, local_diff_path)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print(f"[INFO] Downloaded diff to {local_diff_path}")

    # The diff may come from the network: weights_only=True refuses arbitrary pickles.
    diff_state = torch.load(local_diff_path, map_location="cpu", weights_only=True)
    model_state = model.state_dict()
    model_state.update(diff_state)
    # Strict load: unknown keys or shape mismatches in the diff raise here.
    model.load_state_dict(model_state)

    model = model.to(device)
    if eval_mode:
        # Inference: BatchNorm uses its running statistics instead of batch statistics.
        model.eval()
    return model


Now, whether or not you have the base (or fine-tuned) weights stored locally, you can get the inference model as easily as this:

In [None]:
# base_model_path=None -> TorchVision ImageNet weights are used as the base.
# Point it at a local ResNet50 .pth state dict (e.g. "path/to/your/base_resnet50.pth")
# to skip that download.
model = load_finetuned_resnet50(base_model_path=None)