In [1]:
import os
import torch

In [2]:
# Directory holding the backbone checkpoint files (.pth). Set TELE_CYTO_CHECKPOINT_DIR to point
# elsewhere instead of editing this machine-specific relative path.
LOC = os.environ.get('TELE_CYTO_CHECKPOINT_DIR', '../repo/DavidDov/tele_cyto_models/hub/checkpoints/')

In [3]:
# Sorted so the displayed listing is stable across runs/filesystems (os.listdir order is arbitrary).
sorted(os.listdir(LOC))

['vgg11_bn-6002323d.pth', 'mobilenet_v2-b0353104.pth']

In [4]:
# Load the MobileNetV2 backbone weights (use 'vgg11_bn-6002323d.pth' for the VGG11 backbone).
# map_location='cpu' lets checkpoints saved from a GPU load on CPU-only machines;
# weights_only=True restricts unpickling to tensors/plain containers, which is all a state_dict needs.
checkpoint = torch.load(
    os.path.join(LOC, 'mobilenet_v2-b0353104.pth'),
    map_location='cpu',
    weights_only=True,
)

In [5]:
from typing import Tuple
import torch
import torch.nn as nn
from torchvision import models


class VGG11Model(nn.Module):
    """VGG11-BN backbone producing one malignancy logit and four Bethesda logits per patch.

    Attributes:
        biases (torch.Tensor): four Bethesda thresholds subtracted from the malignancy logit to
            form the Bethesda logits. An ``nn.Parameter`` when ``params['trainable_biases']`` is
            truthy, otherwise a non-persistent buffer (not saved in the state_dict).
        vgg_features (torch.nn.Module): VGG11-BN convolutional feature extractor.
        features (torch.nn.Sequential): Linear(512*4*4 -> 16) + BatchNorm1d + ReLU.
        classifier (torch.nn.Sequential): BatchNorm1d + ReLU + Linear(16 -> 1).
    """

    # logit(0.8), logit(0.6), logit(0.4), logit(0.2): consecutive thresholds are 0.2 apart
    # after a sigmoid.
    _INITIAL_BIASES = (1.386, 0.405, -0.405, -1.386)

    def __init__(self, params: dict):
        """Build the model.

        Args:
            params (dict): hyperparameters. Reads ``'trainable_biases'`` (truthy -> learn the
                Bethesda thresholds) and ``'pretrain'`` (forwarded to
                ``torchvision.models.vgg11_bn`` as ``pretrained=``).
        """
        super().__init__()

        initial_biases = torch.tensor(self._INITIAL_BIASES)
        if params['trainable_biases']:
            self.biases = nn.Parameter(initial_biases)
        else:
            # A buffer follows .to(device)/.cuda() with the rest of the module (a plain tensor
            # attribute does not); persistent=False keeps it out of the state_dict, as before.
            self.register_buffer('biases', initial_biases, persistent=False)

        # ImageNet weights are loaded only when params['pretrain'] is truthy.
        vgg11 = models.vgg11_bn(pretrained=params['pretrain'])

        # TODO(dd208): explain where this checkpoint comes from; the load is currently disabled.
        # vgg11.load_state_dict(torch.load('../params_soft_link/vgg11_bn-6002323d.pth'))

        self.vgg_features = vgg11.features
        # The 512*4*4 input size fixes the backbone output to [*, 512, 4, 4]
        # (presumably 128x128 patches for VGG's stride-32 features — TODO confirm).
        self.features = nn.Sequential(
            nn.Linear(512 * 4 * 4, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(True))
        # NOTE: `features` already ends in BatchNorm1d+ReLU and `classifier` starts with another
        # BatchNorm1d+ReLU; the layout is kept so existing checkpoints still load.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(16),
            nn.ReLU(True),
            nn.Linear(16, 1))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """VGG11 classifier forward pass.

        Args:
            z (torch.Tensor): [B, P, C, H, W] input (B=batch, P=patches, CHW=image dims).

        Returns:
            malignancy_logits: [B*P, 1] malignancy logits.
            bethesda_logits: [B*P, 4] Bethesda logits, ``malignancy_logits - biases``.
        """
        # [B, P, C, H, W] -> [B*P, C, H, W]; reshape (unlike view) also accepts non-contiguous input.
        z = z.reshape(z.shape[0] * z.shape[1], *z.shape[2:])
        # [B*P, C, H, W] -> [B*P, 512, 4, 4] -> [B*P, 512*4*4].
        z = torch.flatten(self.vgg_features(z), 1)
        # [B*P, 512*4*4] -> [B*P, 16].
        z = self.features(z)
        # [B*P, 16] -> [B*P, 1].
        malignancy_logits = self.classifier(z)
        # Broadcast [B*P, 1] - [4] -> [B*P, 4] (same values as repeating both operands).
        bethesda_logits = malignancy_logits - self.biases
        return malignancy_logits, bethesda_logits


class MobileNetV2Model(nn.Module):
    """MobileNetV2 backbone producing one malignancy logit and four Bethesda logits per patch.

    Attributes:
        biases (torch.Tensor): four Bethesda thresholds subtracted from the malignancy logit to
            form the Bethesda logits. An ``nn.Parameter`` when ``params['trainable_biases']`` is
            truthy, otherwise a non-persistent buffer (not saved in the state_dict).
        mobilenet_features (torch.nn.Module): MobileNetV2 convolutional feature extractor.
        avg_pool_2d (torch.nn.Module): global spatial average pool,
            [*, 1280, H, W] -> [*, 1280, 1, 1].
        classifier (torch.nn.Linear): Linear(1280 -> 1) malignancy classifier.
    """

    # logit(0.8), logit(0.6), logit(0.4), logit(0.2): consecutive thresholds are 0.2 apart
    # after a sigmoid.
    _INITIAL_BIASES = (1.386, 0.405, -0.405, -1.386)

    def __init__(self, params: dict):
        """Build the model.

        Args:
            params (dict): hyperparameters. Reads ``'trainable_biases'`` (truthy -> learn the
                Bethesda thresholds) and ``'pretrain'`` (forwarded to
                ``torchvision.models.mobilenet_v2`` as ``pretrained=``).
        """
        super().__init__()

        initial_biases = torch.tensor(self._INITIAL_BIASES)
        if params['trainable_biases']:
            self.biases = nn.Parameter(initial_biases)
        else:
            # A buffer follows .to(device)/.cuda() with the rest of the module (a plain tensor
            # attribute does not); persistent=False keeps it out of the state_dict, as before.
            self.register_buffer('biases', initial_biases, persistent=False)

        # ImageNet weights are loaded only when params['pretrain'] is truthy.
        mobilenet = models.mobilenet_v2(pretrained=params['pretrain'])
        self.mobilenet_features = mobilenet.features

        # Adaptive pooling averages the whole feature map for any spatial size. On 4x4 maps it is
        # identical to the former AvgPool2d(kernel_size=4), but that layer averaged only the
        # top-left 4x4 window of larger maps (e.g. the 7x7 maps from 224x224 crops) and produced
        # >1x1 outputs from 8x8+ maps, which broke the 1280-input classifier.
        self.avg_pool_2d = nn.AdaptiveAvgPool2d(output_size=1)

        self.classifier = nn.Linear(in_features=1280, out_features=1)

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """MobileNetV2 classifier forward pass.

        Args:
            z (torch.Tensor): [B, P, C, H, W] input (B=batch, P=patches, CHW=image dims).

        Returns:
            malignancy_logits: [B*P, 1] malignancy logits.
            bethesda_logits: [B*P, 4] Bethesda logits, ``malignancy_logits - biases``.
        """
        # [B, P, C, H, W] -> [B*P, C, H, W]; reshape (unlike view) also accepts non-contiguous input.
        z = z.reshape(z.shape[0] * z.shape[1], *z.shape[2:])
        # [B*P, C, H, W] -> [B*P, 1280, h, w].
        z = self.mobilenet_features(z)
        # [B*P, 1280, h, w] -> [B*P, 1280, 1, 1] -> [B*P, 1280].
        z = torch.flatten(self.avg_pool_2d(z), 1)
        # [B*P, 1280] -> [B*P, 1].
        malignancy_logits = self.classifier(z)
        # Broadcast [B*P, 1] - [4] -> [B*P, 4] (same values as repeating both operands).
        bethesda_logits = malignancy_logits - self.biases
        return malignancy_logits, bethesda_logits

In [6]:
# Hyperparameters read by the model classes: 'pretrain' is forwarded to torchvision as
# `pretrained=`, and a truthy 'trainable_biases' makes the Bethesda thresholds learnable.
params = dict(pretrain=0, trainable_biases=1)

In [7]:
# Kept under its own name so the MobileNetV2 `model` built in the next cell does not silently
# replace it.
vgg_model = VGG11Model(params)



In [8]:
# Bare torchvision MobileNetV2 (ImageNet weights only if params['pretrain'] is truthy);
# its weights are then filled from `checkpoint`.
model = models.mobilenet_v2(pretrained=params['pretrain'])

In [9]:
# strict=True is the default: every checkpoint key must match a parameter/buffer of `model`.
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [10]:
# Display the MobileNetV2 convolutional feature extractor (the same module as `model.features`).
model.get_submodule('features')

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): InvertedResidual(
    (conv): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): InvertedResidual(
    (conv): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (

In [14]:
import torch
import torchvision
from torchvision.transforms import v2

# Standard ImageNet-style evaluation preprocessing.
# NOTE(review): v2.ToTensor is deprecated in newer torchvision releases; consider
# v2.ToImage() + v2.ToDtype(torch.float32, scale=True) once the installed version supports them.
# NOTE(review): 224x224 crops give 7x7 MobileNetV2 feature maps (see the output below), while the
# model classes' comments assume 4x4 backbone outputs — confirm the intended input size.
test_transform = v2.Compose(
    [
        v2.Resize(256),  # shorter side -> 256 px
        v2.CenterCrop(224),  # 224 x 224 center crop
        v2.ToTensor(),  # PIL image -> float CHW tensor
        # ImageNet per-channel mean / std.
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)



In [20]:
from PIL import Image

# NOTE(review): absolute, machine-specific path whose directory name appears to contain a patient
# name — move it to a config constant / env var and redact before sharing this notebook.
IMAGE_PATH = '/home/quan/work/thyroid/data/NOH/001.NguyeThiLan- Right/IMG_20221219_101351.jpg'

# Transform inside the `with` so the file handle is closed once the pixels have been read.
with Image.open(IMAGE_PATH) as pil_image:
    image = test_transform(pil_image).unsqueeze(0)  # add a batch dimension

# eval(): BatchNorm uses the checkpoint's running statistics instead of this single image's
# statistics (and stops overwriting them); no_grad(): no autograd graph for pure inference.
model.eval()
with torch.no_grad():
    output = model.features(image)
# output = output.argmax(dim=1).cpu().numpy()[0]

In [27]:
# Shape of the preprocessed, batched input image.
image.size()

torch.Size([1, 3, 224, 224])

In [21]:
# Shape of the MobileNetV2 feature map for the input image.
output.size()

torch.Size([1, 1280, 7, 7])

In [22]:
# Global average pool: averages the full H x W map down to 1 x 1 for any input size.
# AvgPool2d(kernel_size=4) also returns 1 x 1 on the 7 x 7 map above, but it averages only the
# top-left 4 x 4 window and silently drops the other 33 of 49 positions.
avg_pool_2d = torch.nn.AdaptiveAvgPool2d(output_size=1)

In [24]:
# Spatially pool the feature map computed above with `avg_pool_2d` (yields a 1 x 1 map here).
x = avg_pool_2d(output)

In [25]:
# Shape after pooling.
x.size()

torch.Size([1, 1280, 1, 1])

In [None]:
class QuanVGG11Model(nn.Module):
    """VGG11-BN backbone producing one malignancy logit and four Bethesda logits per patch.

    NOTE(review): this duplicates ``VGG11Model`` layer-for-layer; consider subclassing or aliasing
    it instead of maintaining two copies.

    Attributes:
        biases (torch.Tensor): four Bethesda thresholds subtracted from the malignancy logit to
            form the Bethesda logits. An ``nn.Parameter`` when ``params['trainable_biases']`` is
            truthy, otherwise a non-persistent buffer (not saved in the state_dict).
        vgg_features (torch.nn.Module): VGG11-BN convolutional feature extractor.
        features (torch.nn.Sequential): Linear(512*4*4 -> 16) + BatchNorm1d + ReLU.
        classifier (torch.nn.Sequential): BatchNorm1d + ReLU + Linear(16 -> 1).
    """

    # logit(0.8), logit(0.6), logit(0.4), logit(0.2): consecutive thresholds are 0.2 apart
    # after a sigmoid.
    _INITIAL_BIASES = (1.386, 0.405, -0.405, -1.386)

    def __init__(self, params: dict):
        """Build the model.

        Args:
            params (dict): hyperparameters. Reads ``'trainable_biases'`` (truthy -> learn the
                Bethesda thresholds) and ``'pretrain'`` (forwarded to
                ``torchvision.models.vgg11_bn`` as ``pretrained=``).
        """
        super().__init__()

        initial_biases = torch.tensor(self._INITIAL_BIASES)
        if params['trainable_biases']:
            self.biases = nn.Parameter(initial_biases)
        else:
            # A buffer follows .to(device)/.cuda() with the rest of the module (a plain tensor
            # attribute does not); persistent=False keeps it out of the state_dict, as before.
            self.register_buffer('biases', initial_biases, persistent=False)

        # ImageNet weights are loaded only when params['pretrain'] is truthy.
        vgg11 = models.vgg11_bn(pretrained=params['pretrain'])

        # TODO(dd208): explain where this checkpoint comes from; the load is currently disabled.
        # vgg11.load_state_dict(torch.load('../params_soft_link/vgg11_bn-6002323d.pth'))

        self.vgg_features = vgg11.features
        # The 512*4*4 input size fixes the backbone output to [*, 512, 4, 4]
        # (presumably 128x128 patches for VGG's stride-32 features — TODO confirm).
        self.features = nn.Sequential(
            nn.Linear(512 * 4 * 4, 16),
            nn.BatchNorm1d(16),
            nn.ReLU(True))
        # NOTE: `features` already ends in BatchNorm1d+ReLU and `classifier` starts with another
        # BatchNorm1d+ReLU; the layout is kept so existing checkpoints still load.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(16),
            nn.ReLU(True),
            nn.Linear(16, 1))

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """VGG11 classifier forward pass.

        Args:
            z (torch.Tensor): [B, P, C, H, W] input (B=batch, P=patches, CHW=image dims).

        Returns:
            malignancy_logits: [B*P, 1] malignancy logits.
            bethesda_logits: [B*P, 4] Bethesda logits, ``malignancy_logits - biases``.
        """
        # [B, P, C, H, W] -> [B*P, C, H, W]; reshape (unlike view) also accepts non-contiguous input.
        z = z.reshape(z.shape[0] * z.shape[1], *z.shape[2:])
        # [B*P, C, H, W] -> [B*P, 512, 4, 4] -> [B*P, 512*4*4].
        z = torch.flatten(self.vgg_features(z), 1)
        # [B*P, 512*4*4] -> [B*P, 16].
        z = self.features(z)
        # [B*P, 16] -> [B*P, 1].
        malignancy_logits = self.classifier(z)
        # Broadcast [B*P, 1] - [4] -> [B*P, 4] (same values as repeating both operands).
        bethesda_logits = malignancy_logits - self.biases
        return malignancy_logits, bethesda_logits

In [1]:
# VGG11Model requires the hyperparameter dict. This cell also depends on the class-definition and
# `params` cells above having run in the current kernel (the NameError below came from a fresh
# kernel where they had not).
vgg_model = VGG11Model(params)

NameError: name 'VGG11Model' is not defined