In [2]:
import argparse
from pathlib import Path
import pandas
import pickle
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import pandas as pd
import torch.nn as nn
from torchvision.models import efficientnet_v2_m
import torch

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = efficientnet_v2_m(weights=None)

        # Freeze the weights of the model
        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the last layer
        for param in self.model.classifier.parameters():
            param.requires_grad = True
        
        # Add a classfier layer for each level
        self.levels = []
        for level in range(1, 7):
            self.levels.append(nn.Sequential(
                nn.Linear(1000, 512),
                nn.ReLU(),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Linear(64, level)
            ))


    def forward(self, x):
        x = self.model(x)
        
        # Get the prediction for each level
        predictions = {}
        for i, level in enumerate(self.levels):
            predictions[f'level_{i + 1}'] = level(x)

        return predictions


In [7]:
model = Model()
model.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
model.eval()

Model(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=