pip install timm
pip install segmentation-models-pytorch

In [3]:
import torch
import timm
from PIL import Image
from torchvision import transforms as T
import numpy as np

# timm

In [4]:
model = timm.create_model('mobilenetv3_large_100', pretrained=True, features_only=True)
model.eval()


Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth" to C:\Users\InColumi/.cache\torch\hub\checkpoints\mobilenetv3_large_100_ra-f55367f5.pth


MobileNetV3Features(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): Hardswish()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (bn1): BatchNormAct2d(
          16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): ReLU(inplace=True)
        )
        (se): Identity()
        (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): Identity()
        )
        (drop_path): Identity()
      )
    )
    (1): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(16

In [5]:
img_path = './2.jpg'

In [6]:
img = Image.open(img_path)
trans = T.Compose([T.Resize((224, 224)), T.ToTensor()])
timg = trans(img)
out = model(timg.unsqueeze(0))
pred = out[0].argmax()

In [7]:
for i in out:
    print(i.shape)

torch.Size([1, 16, 112, 112])
torch.Size([1, 24, 56, 56])
torch.Size([1, 40, 28, 28])
torch.Size([1, 112, 14, 14])
torch.Size([1, 960, 7, 7])


In [8]:
with open('imgnet.txt', 'r') as f:
    classes = [line.rstrip() for line in f]

In [9]:
print(pred)

tensor(62831)


In [10]:
#classes[pred]

In [11]:
example_inputs = torch.rand(1, 3, 224, 224, requires_grad=True)
traced_foo = torch.jit.trace(model, example_inputs)
traced_foo.save('mobilenetv3_large_100.jit')

  module._c._create_method_from_trace(


In [12]:
load_traced_model = torch.jit.load('mobilenetv3_large_100.jit')
load_traced_model.eval()

RecursiveScriptModule(
  original_name=MobileNetV3Features
  (conv_stem): RecursiveScriptModule(original_name=Conv2d)
  (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
  (act1): RecursiveScriptModule(original_name=Hardswish)
  (blocks): RecursiveScriptModule(
    original_name=Sequential
    (0): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=DepthwiseSeparableConv
        (conv_dw): RecursiveScriptModule(original_name=Conv2d)
        (bn1): RecursiveScriptModule(
          original_name=BatchNormAct2d
          (drop): RecursiveScriptModule(original_name=Identity)
          (act): RecursiveScriptModule(original_name=ReLU)
        )
        (se): RecursiveScriptModule(original_name=Identity)
        (conv_pw): RecursiveScriptModule(original_name=Conv2d)
        (bn2): RecursiveScriptModule(
          original_name=BatchNormAct2d
          (drop): RecursiveScriptModule(original_name=Identity)
          (act): Recursi

# SMP

In [15]:
import segmentation_models_pytorch as smp

model_with_imagenet = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=2,                      
)

out = model_with_imagenet(timg.unsqueeze(0))
pred = torch.softmax(out, dim=1)[0][0]
mask = (pred.cpu().detach().numpy()*255).astype(np.uint8)
Image.fromarray(mask).save('out_1.png')

example_inputs = torch.rand(1, 3, 224, 224, requires_grad=True)
traced_foo = torch.jit.trace(model_with_imagenet, example_inputs)
traced_foo.save('resnet18.jit')

load_traced_model = torch.jit.load('resnet18.jit')
load_traced_model.eval()

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\InColumi/.cache\torch\hub\checkpoints\resnet18-5c106cde.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

  if h % output_stride != 0 or w % output_stride != 0:


RecursiveScriptModule(
  original_name=Unet
  (encoder): RecursiveScriptModule(
    original_name=ResNetEncoder
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
    (relu): RecursiveScriptModule(original_name=ReLU)
    (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
    (layer1): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=BasicBlock
        (conv1): RecursiveScriptModule(original_name=Conv2d)
        (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
        (relu): RecursiveScriptModule(original_name=ReLU)
        (conv2): RecursiveScriptModule(original_name=Conv2d)
        (bn2): RecursiveScriptModule(original_name=BatchNorm2d)
      )
      (1): RecursiveScriptModule(
        original_name=BasicBlock
        (conv1): RecursiveScriptModule(original_name=Conv2d)
        (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
        (rel

In [None]:
import segmentation_models_pytorch as smp

model_with_ssl = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="ssl",     
    in_channels=3,                  
    classes=2,                      
)

out = model_with_ssl(timg.unsqueeze(0))
pred = torch.softmax(out, dim=1)[0][0]
mask = (pred.cpu().detach().numpy()*255).astype(np.uint8)
Image.fromarray(mask).save('out_2.png')

example_inputs = torch.rand(1, 3, 224, 224, requires_grad=True)
traced_model = torch.jit.trace(model_with_imagenet, example_inputs)
traced_model.save('resnet18.2.jit')

load_traced_model = torch.jit.load('resnet18.2.jit')
load_traced_model.eval()

Downloading: "https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_supervised_resnet18-d92f0530.pth" to C:\Users\InColumi/.cache\torch\hub\checkpoints\semi_supervised_resnet18-d92f0530.pth


  0%|          | 0.00/44.6M [00:00<?, ?B/s]

## Cохраните веса (state_dict) encoder части сегментационной модели в отдельный файл

In [40]:
model_with_imagenet = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=2,                      
)

em = timm.create_model('resnet18', features_only=True)
torch.save(model_with_imagenet.encoder.state_dict(), 'tmp_resnet18.pt')
em.load_state_dict(torch.load('tmp_resnet18.pt'))

<All keys matched successfully>