# IONet model

---

In [13]:
! pip install timm
! pip install transformers

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl.metadata (38 kB)
Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
   ---------------------------------------- 0.0/2.2 MB ? eta -:--:--
   ---- ----------------------------------- 0.3/2.2 MB 5.2 MB/s eta 0:00:01
   ------ --------------------------------- 0.4/2.2 MB 4.0 MB/s eta 0:00:01
   ------------- -------------------------- 0.8/2.2 MB 5.4 MB/s eta 0:00:01
   ------------------ --------------------- 1.1/2.2 MB 5.6 MB/s eta 0:00:01
   ------------------------ --------------- 1.4/2.2 MB 5.9 MB/s eta 0:00:01
   ----------------------------- ---------- 1.7/2.2 MB 5.6 MB/s eta 0:00:01
   --------------------------------- ------ 1.9/2.2 MB 5.4 MB/s eta 0:00:01
   ------------------------------------- -- 2.1/2.2 MB 5.5 MB/s eta 0:00:01
   ---------------------------------------  2.2/2.2 MB 5.5 MB/s eta 0:00:01
   ---------------------------------------  2.2/2.2 MB 5.5 MB/s eta 0:00:01
   ---------------------------------------- 2

In [2]:
# import lib
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import tqdm
from tqdm.auto import trange, tqdm

# import pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# to get CFIAR10 dataset
from torchvision import transforms
import torchvision
import torchvision.transforms as transforms

# to import pretrained models
from transformers import AutoImageProcessor, MobileNetV1Model
import timm

# import sklearn
from sklearn.model_selection import train_test_split

# set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


---
## Misc

In [3]:
def cvtImg(img):
    img = img.permute([0, 2, 3, 1])
    img = img - img.min()
    img = (img / img.max())
    return img.numpy().astype(np.float32)


def show_examples(x):
    plt.figure(figsize=(10, 10))
    imgs = cvtImg(x)
    for i in range(25):
        plt.subplot(5, 5, i+1)
        plt.imshow(imgs[i])
        plt.axis('off')

In [4]:
class ChannelShuffle(nn.Module):
    def __init__(self, group=1):
        assert group > 1
        super(ChannelShuffle, self).__init__()
        self.group = group

    def forward(self, x):
        """https://github.com/Randl/ShuffleNetV2-pytorch/blob/master/model.py
        """
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % self.group == 0)
        channels_per_group = num_channels // self.group
        # reshape
        x = x.view(batchsize, self.group, channels_per_group, height, width)
        # transpose
        # - contiguous() required if transpose() is used before view().
        #   See https://github.com/pytorch/pytorch/issues/764
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten
        x = x.view(batchsize, -1, height, width)
        return x


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class FBdecoderLayer(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride,
                 expansion, group, bn=False):
        super(FBdecoderLayer, self).__init__()
        assert not bn, "not support bn for now"
        bias_flag = not bn
        if kernel_size == 1:
            padding = 0
        elif kernel_size == 3:
            padding = 1
        elif kernel_size == 5:
            padding = 2
        elif kernel_size == 7:
            padding = 3
        else:
            raise ValueError("Not supported kernel_size %d" % kernel_size)
        if group == 1:
            self.op = nn.Sequential(
                nn.Conv2d(C_in, C_out*expansion, 1, stride=1, padding=0,
                          groups=group, bias=bias_flag),
                nn.ReLU(inplace=False),
                nn.Conv2d(C_out*expansion, C_out*expansion, kernel_size, stride=stride,
                          padding=padding, groups=C_out*expansion, bias=bias_flag),
                nn.ReLU(inplace=False),
                nn.Conv2d(C_out*expansion, C_out, 1, stride=1, padding=0,
                          groups=group, bias=bias_flag)
            )
        else:
            self.op = nn.Sequential(
                nn.Conv2d(C_in, C_out*expansion, 1, stride=1, padding=0,
                          groups=group, bias=bias_flag),
                nn.ReLU(inplace=False),
                ChannelShuffle(group),
                nn.Conv2d(C_out*expansion, C_out*expansion, kernel_size, stride=stride,
                          padding=padding, groups=C_out*expansion, bias=bias_flag),
                nn.ReLU(inplace=False),
                nn.Conv2d(C_out*expansion, C_out, 1, stride=1, padding=0,
                          groups=group, bias=bias_flag),
                ChannelShuffle(group)
            )
        res_flag = ((C_in == C_out) and (stride == 1))
        self.res_flag = res_flag
        if not res_flag:
            if stride == 2:
                self.trans = nn.Conv2d(C_in, C_out, 3, stride=2,
                                       padding=1)
            elif stride == 1:
                self.trans = nn.Conv2d(C_in, C_out, 1, stride=1,
                                       padding=0)
            else:
                raise ValueError("Wrong stride %d provided" % stride)

    def forward(self, x):
        if self.res_flag:
            return self.op(x) + x
        else:
            return self.op(x) + self.trans(x)

In [5]:
model = FBdecoderLayer(C_in=256, C_out=256, kernel_size=3,
                       stride=1, expansion=1, group=1, bn=False)

# random input
x = torch.randn(1, 256, 32, 32)

# forward
y = model(x)

print(y.shape)

model = FBdecoderLayer(C_in=512, C_out=256, kernel_size=3,
                       stride=1, expansion=1, group=1, bn=False)
cnn_model = nn.Conv2d(512, 256, 3, stride=1, padding=1)

# random input
x = torch.randn(1, 512, 32, 32)

# forward
y = model(x)
y_cnn = cnn_model(x)

print(y.shape)
print(y_cnn.shape)

print(model)

torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 256, 32, 32])
FBdecoderLayer(
  (op): Sequential(
    (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=256)
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (trans): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
)


In [29]:
class ChannelShuffle(nn.Module):
    def __init__(self, group=1):
        assert group > 1
        super(ChannelShuffle, self).__init__()
        self.group = group

    def forward(self, x):
        """https://github.com/Randl/ShuffleNetV2-pytorch/blob/master/model.py
        """
        batchsize, num_channels, height, width = x.data.size()
        assert (num_channels % self.group == 0)
        channels_per_group = num_channels // self.group
        # reshape
        x = x.view(batchsize, self.group, channels_per_group, height, width)
        # transpose
        # - contiguous() required if transpose() is used before view().
        #   See https://github.com/pytorch/pytorch/issues/764
        x = torch.transpose(x, 1, 2).contiguous()
        # flatten
        x = x.view(batchsize, -1, height, width)
        return x


class identity(nn.Module):
    def __init__(self):
        super(identity, self).__init__()

    def forward(self, x):
        return x

# based on FB block


class DecoderLayer(nn.Module):
    def __init__(self, in_channels, out_channels, expansion=3):

        super(DecoderLayer, self).__init__()
        self.conv1 = nn.ConvTranspose2d(
            in_channels, in_channels*expansion, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(in_channels*expansion)
        self.relu1 = nn.ReLU(inplace=True)

        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.conv2 = nn.Conv2d(
            in_channels*expansion, in_channels*expansion, kernel_size=5, padding=2, stride=1)
        self.bn2 = nn.BatchNorm2d(in_channels*expansion)
        self.relu2 = nn.ReLU(inplace=True)

        self.conv3 = nn.Conv2d(in_channels*expansion,
                               out_channels, kernel_size=1, stride=1)
        self.bn3 = nn.BatchNorm2d(out_channels)

        # self.skip_connection = None

        # if in_channels != out_channels:
        #   self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1)
        # else:
        #   self.skip_connection = nn.Identity()

    def forward(self, x):
        in_x = x

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.upsample(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)

        # TODO add skip connection

        # print("layer output:", x.shape)
        # print("skip connection:", self.skip_connection(in_x).shape)

        # add skip connection
        # x = x + self.skip_connection(in_x)

        return x


# random input
x = torch.randn(1, 256, 32, 32)

# forward
model = DecoderLayer(256, 256)

y = model(x)

print(y.shape)

torch.Size([1, 256, 64, 64])


---
## MobileNetV2: Inverted Residuals and Linear Bottlenecks

In [32]:
# Initialize the MobileNetV2 model
image_processor = AutoImageProcessor.from_pretrained(
    "google/mobilenet_v1_1.0_224")
mobilenet_v1_model = MobileNetV1Model.from_pretrained(
    "google/mobilenet_v1_1.0_224")

mobilenet_v1_model

MobileNetV1Model(
  (conv_stem): MobileNetV1ConvLayer(
    (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (layer): ModuleList(
    (0): MobileNetV1ConvLayer(
      (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (1): MobileNetV1ConvLayer(
      (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (2): MobileNetV1ConvLayer(
      (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), groups=64, bias=False)
      (normalization): BatchNorm2d(64, eps=0.001, momentum=0.999

In [33]:
# save
mobilenet_v1_model.save_pretrained("../Pretrained_models/mobilenet_v1")

mobilenet_v1_model

MobileNetV1Model(
  (conv_stem): MobileNetV1ConvLayer(
    (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
    (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
    (activation): ReLU6()
  )
  (layer): ModuleList(
    (0): MobileNetV1ConvLayer(
      (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (1): MobileNetV1ConvLayer(
      (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
      (activation): ReLU6()
    )
    (2): MobileNetV1ConvLayer(
      (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), groups=64, bias=False)
      (normalization): BatchNorm2d(64, eps=0.001, momentum=0.999

In [34]:
# get first layer
first_layer = mobilenet_v1_model.conv_stem

print(first_layer)

MobileNetV1ConvLayer(
  (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
  (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
  (activation): ReLU6()
)


In [35]:
import torch.nn as nn

model = mobilenet_v1_model

# Define the number of layers per block
layers_per_block = 4

# Calculate the total number of blocks
total_layers = len(model.layer)
num_blocks = total_layers // layers_per_block

mobilenet_seq_blocks = []

# Iterate through the layers and divide them into blocks
for block_idx in range(num_blocks):
    start_idx = block_idx * layers_per_block
    end_idx = (block_idx + 1) * layers_per_block
    block_layers = model.layer[start_idx:end_idx]

    # Create a block
    block = nn.Sequential(*block_layers)
    setattr(model, f"block{block_idx}", block)

    print(f"Block {block_idx}:")

    # Print input, output, and kernel size for each layer
    for layer_idx, layer in enumerate(block, start=start_idx):
        if isinstance(layer.convolution, nn.Conv2d):
            print(f"  Layer {layer_idx}: Input: {layer.convolution.in_channels}, Output: {layer.convolution.out_channels}, Kernel: {layer.convolution.kernel_size}")

    mobilenet_seq_blocks.append(block)

# add the remaining layers
start_idx = num_blocks * layers_per_block

# Create a block
block = nn.Sequential(*model.layer[start_idx:])
setattr(model, f"block{num_blocks}", block)

print(f"Block {num_blocks}:")
for layer_idx, layer in enumerate(block, start=start_idx):
    if isinstance(layer.convolution, nn.Conv2d):
        print(f"  Layer {layer_idx}: Input: {layer.convolution.in_channels}, Output: {layer.convolution.out_channels}, Kernel: {layer.convolution.kernel_size}")

mobilenet_seq_blocks.append(block)

mobilenet_seq_blocks

Block 0:
  Layer 0: Input: 32, Output: 32, Kernel: (3, 3)
  Layer 1: Input: 32, Output: 64, Kernel: (1, 1)
  Layer 2: Input: 64, Output: 64, Kernel: (3, 3)
  Layer 3: Input: 64, Output: 128, Kernel: (1, 1)
Block 1:
  Layer 4: Input: 128, Output: 128, Kernel: (3, 3)
  Layer 5: Input: 128, Output: 128, Kernel: (1, 1)
  Layer 6: Input: 128, Output: 128, Kernel: (3, 3)
  Layer 7: Input: 128, Output: 256, Kernel: (1, 1)
Block 2:
  Layer 8: Input: 256, Output: 256, Kernel: (3, 3)
  Layer 9: Input: 256, Output: 256, Kernel: (1, 1)
  Layer 10: Input: 256, Output: 256, Kernel: (3, 3)
  Layer 11: Input: 256, Output: 512, Kernel: (1, 1)
Block 3:
  Layer 12: Input: 512, Output: 512, Kernel: (3, 3)
  Layer 13: Input: 512, Output: 512, Kernel: (1, 1)
  Layer 14: Input: 512, Output: 512, Kernel: (3, 3)
  Layer 15: Input: 512, Output: 512, Kernel: (1, 1)
Block 4:
  Layer 16: Input: 512, Output: 512, Kernel: (3, 3)
  Layer 17: Input: 512, Output: 512, Kernel: (1, 1)
  Layer 18: Input: 512, Output: 512,

[Sequential(
   (0): MobileNetV1ConvLayer(
     (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
     (normalization): BatchNorm2d(32, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
     (activation): ReLU6()
   )
   (1): MobileNetV1ConvLayer(
     (convolution): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
     (activation): ReLU6()
   )
   (2): MobileNetV1ConvLayer(
     (convolution): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), groups=64, bias=False)
     (normalization): BatchNorm2d(64, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)
     (activation): ReLU6()
   )
   (3): MobileNetV1ConvLayer(
     (convolution): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
     (normalization): BatchNorm2d(128, eps=0.001, momentum=0.9997, affine=True, track_running_stats=True)

---
## Unet class

en_stim: 3 -> 32

en0: 32 -> 128

en1: 128 -> 256

en2: 256 -> 512

en3: 512 -> 512

en4: 512 -> 512

en5: 512 -> 1024

en6: 1024 -> 1024

de6: 1024 -> 1024

de5: 1024 + 1024 -> 512

de4: 512 + 512 -> 512

de3: 512 + 512 -> 512

de2: 512 + 512 -> 256

de1: 256 + 256 -> 128

de0: 128 + 128 -> 32

out_stim: 32 -> 3

---

In [38]:
class Unet(nn.Module):
	def __init__(self):
		super(Unet, self).__init__()

		# Encoder layers
		
		# fist layer is moblienet_v1 layer 1
		self.en_stim = mobilenet_v1_model.conv_stem
		
		# now we have the blocks
		self.en0 = mobilenet_seq_blocks[0]
		self.en1 = mobilenet_seq_blocks[1]
		self.en2 = mobilenet_seq_blocks[2]
		self.en3 = mobilenet_seq_blocks[3]
		self.en4 = mobilenet_seq_blocks[4]
		self.en5 = mobilenet_seq_blocks[5]
		self.en6 = mobilenet_seq_blocks[6]

		# Decoder layers
		self.de6 = DecoderLayer(1024, 1024)
		self.de5 = DecoderLayer(1024+1024, 512)
		self.de4 = DecoderLayer(512+512, 256)
		self.de3 = DecoderLayer(256+256, 128)
		self.de2 = DecoderLayer(128+128, 64)
		self.de1 = DecoderLayer(64+64, 32)
		self.de0 = DecoderLayer(32+32, 32)

		# self.de6 = FBdecoderLayer(C_in=1024,C_out=1024,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de5 = FBdecoderLayer(C_in=1024+1024,C_out=512,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de4 = FBdecoderLayer(C_in=512+512,C_out=256,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de3 = FBdecoderLayer(C_in=256+256,C_out=128,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de2 = FBdecoderLayer(C_in=128+128,C_out=64,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de1 = FBdecoderLayer(C_in=64+64,C_out=32,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		# self.de0 = FBdecoderLayer(C_in=32+32,C_out=32,kernel_size=5,stride=1,expansion=1,group=1,bn=False)
		
		# last layer
		self.out_stim = nn.Sequential(
			nn.Conv2d(32, 3, kernel_size=3, stride=2, bias=False),
			nn.BatchNorm2d(3),
			nn.ReLU6()
		)
		
		# Output layer
		self.out_stim = nn.Conv2d(32, 3, kernel_size=3, padding=1)

	def forward(self, x):
		# Encoder pass
		en_stim_out = F.relu(self.en_stim(x))
		print("en_stim_out shape:", en_stim_out.shape)

		en0_out = F.relu(self.en0(en_stim_out))
		print("en0_out shape:", en0_out.shape)

		en1_out = F.relu(self.en1(en0_out))
		print("en1_out shape:", en1_out.shape)

		en2_out = F.relu(self.en2(en1_out))
		print("en2_out shape:", en2_out.shape)

		en3_out = F.relu(self.en3(en2_out))
		print("en3_out shape:", en3_out.shape)

		en4_out = F.relu(self.en4(en3_out))
		print("en4_out shape:", en4_out.shape)

		en5_out = F.relu(self.en5(en4_out))
		print("en5_out shape:", en5_out.shape)

		en6_out = F.relu(self.en6(en5_out))
		print("en6_out shape:", en6_out.shape)

		# Decoder pass with skip connections
		de6_out = F.relu(self.de6(en6_out))
		print("\nde6_out shape:", de6_out.shape)
		print("en5_out shape:", en5_out.shape)
  
		print("de5_in shape:", torch.cat([de6_out, en5_out], dim=1).shape)

		de5_out = F.relu(self.de5(torch.cat([de6_out, en5_out], dim=1)))
		print("de5_out shape:", de5_out.shape)

		print("en4_out shape:", en4_out.shape)
		print("de5_in shape:", torch.cat([de5_out, en4_out], dim=1).shape)
		de4_out = F.relu(self.de4(torch.cat([de5_out, en4_out], dim=1)))
		print("de4_out shape:", de4_out.shape)

		de3_out = F.relu(self.de3(torch.cat([de4_out, en3_out], dim=1)))
		print("de3_out shape:", de3_out.shape)

		de2_out = F.relu(self.de2(torch.cat([de3_out, en2_out], dim=1)))
		print("de2_out shape:", de2_out.shape)

		de1_out = F.relu(self.de1(torch.cat([de2_out, en1_out], dim=1)))
		print("de1_out shape:", de1_out.shape)

		de0_out = F.relu(self.de0(torch.cat([de1_out, en0_out], dim=1)))
		print("de0_out shape:", de0_out.shape)

		# Output prediction
		output = self.out_stim(de0_out)
		print("output shape:", output.shape)
		return output


	# def forward(self, x):
	#     # Encoder pass
	#     en_stim_out = F.relu(self.en_stim(x))
		
	#     en0_out = F.relu(self.en0(en_stim_out))
	#     en1_out = F.relu(self.en1(en0_out))
	#     en2_out = F.relu(self.en2(en1_out))
	#     en3_out = F.relu(self.en3(en2_out))
	#     en4_out = F.relu(self.en4(en3_out))
	#     en5_out = F.relu(self.en5(en4_out))
	#     en6_out = F.relu(self.en6(en5_out))

	#     # Decoder pass with skip connections
	#     de6_out = F.relu(self.de6(en6_out))
	#     de5_out = F.relu(self.de5(torch.cat([de6_out, en5_out], dim=1)))
	#     de4_out = F.relu(self.de4(torch.cat([de5_out, en4_out], dim=1)))
	#     de3_out = F.relu(self.de3(torch.cat([de4_out, en3_out], dim=1)))
	#     de2_out = F.relu(self.de2(torch.cat([de3_out, en2_out], dim=1)))
	#     de1_out = F.relu(self.de1(torch.cat([de2_out, en1_out], dim=1)))
	#     de0_out = F.relu(self.de0(torch.cat([de1_out, en0_out], dim=1)))

	#     # Output prediction
	#     output = self.out_stim(de0_out)
	#     return output


In [39]:
model = Unet().to(device)

# random input
x = torch.randn(1, 3, 224, 224)

# forward
y = model(x.to(device))

print(y.shape)

en_stim_out shape: torch.Size([1, 32, 112, 112])
en0_out shape: torch.Size([1, 128, 56, 56])
en1_out shape: torch.Size([1, 256, 28, 28])
en2_out shape: torch.Size([1, 512, 14, 14])
en3_out shape: torch.Size([1, 512, 14, 14])
en4_out shape: torch.Size([1, 512, 14, 14])
en5_out shape: torch.Size([1, 1024, 7, 7])
en6_out shape: torch.Size([1, 1024, 7, 7])

de6_out shape: torch.Size([1, 1024, 14, 14])
en5_out shape: torch.Size([1, 1024, 7, 7])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 7 for tensor number 1 in the list.

---
## FB net

In [18]:
fbnet_100_model = timm.create_model('fbnetc_100', pretrained=True)

fbnet_100_model

EfficientNet(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): ReLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): InvertedResidual(
        (conv_pw): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNormAct2d(
          16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): ReLU(inplace=True)
        )
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=16, bias=False)
        (bn2): BatchNormAct2d(
          16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): ReLU(inplace=True)
        )
        (se): Identity()
        (conv_pwl): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
     

In [None]:
# save the model
fbnet_100_model.save_pretrained("../Pretrained_models/fbnet_100")