In [1]:
import matplotlib.pyplot as plt
import timm
import torch
import torch.nn as nn
from PIL import Image
from my_utils import find_files_by_ext

In [None]:
# Quick survey of the installed timm release and its model registry.
# The trailing numbers are the values the author recorded when running this cell.
print(timm.__version__)  # 1.0.21

all_architectures = timm.list_models()
print(len(all_architectures))  # 1279

pretrained_models = timm.list_models(pretrained=True)
print(len(pretrained_models))  # 1689

# Query the EfficientNetV2 checkpoints once and reuse the list for both prints.
effnetv2_pretrained = timm.list_models('*efficientnetv2*', pretrained=True)
print(len(effnetv2_pretrained))  # 21
print(effnetv2_pretrained)

In [None]:
# List every pretrained model whose name contains "convnext"; leaving the call
# as the last expression lets the notebook render the returned list.
timm.list_models(filter='*convnext*', pretrained=True)

In [None]:
# model = timm.create_model('tf_efficientnetv2_s.in1k',
#                           pretrained=True,
#                           cache_dir=r"E:\Git\pytorch-image-models\models")
# model = timm.create_model('resnet50d', pretrained=True,
#                           num_classes=10,
#                           cache_dir=r"E:\Git\pytorch-image-models\models")
# print(model)
# print(model.get_classifier())  # Linear(in_features=2048, out_features=10, bias=True)
# print(model.global_pool)  # SelectAdaptivePool2d(pool_type=avg, flatten=Flatten(start_dim=1, end_dim=-1))
#
# pool_types = ['avg', 'max', 'avgmax', 'catavgmax', '']
#
# for pool in pool_types:
#     model = timm.create_model('resnet50d', pretrained=True,
#                               num_classes=0, global_pool=pool,
#                               cache_dir=r"E:\Git\pytorch-image-models\models")
#     model.eval()
#     feature_output = model(torch.randn(1, 3, 224, 224))
#     print(feature_output.shape)

# Build a 10-class ResNet-50d that uses the 'catavgmax' global pooling mode,
# then swap its stock classifier for a small MLP head.
model = timm.create_model('resnet50d', pretrained=True,
                          num_classes=10, global_pool='catavgmax',
                          cache_dir=r"E:\Git\pytorch-image-models\models")
print(model.get_classifier())
# Read the input width from the existing classifier so the replacement head
# always matches whatever the chosen pooling mode produces.
num_in_features = model.get_classifier().in_features
print(num_in_features)
# Replace the classification head.
model.fc = nn.Sequential(
    nn.BatchNorm1d(num_in_features),
    nn.Linear(in_features=num_in_features, out_features=512, bias=False),
    nn.ReLU(),
    nn.BatchNorm1d(512),
    nn.Dropout(0.4),
    nn.Linear(in_features=512, out_features=10, bias=False)
)
print(model.get_classifier())
# eval() matters for this smoke test: BatchNorm1d cannot compute batch
# statistics from a single-sample batch in training mode, and Dropout must be
# disabled for a deterministic shape check.
model.eval()
# The forward pass is only a shape check, so skip building the autograd graph.
with torch.no_grad():
    output = model(torch.randn(1, 3, 224, 224))
print(output.shape)

In [None]:
# Plain 10-class ResNet-50d (no pooling override); ending the cell on the bare
# name makes the notebook print the module tree.
model = timm.create_model(
    'resnet50d',
    pretrained=True,
    num_classes=10,
    cache_dir=r"E:\Git\pytorch-image-models\models",
)
model

In [None]:
# Same architecture requested with features_only=True, for comparison with the
# classifier variant; the bare name prints the resulting module tree.
# NOTE(review): num_classes is still passed as before - it presumably has no
# effect once features_only is set; confirm against the timm documentation.
model = timm.create_model(
    'resnet50d',
    pretrained=True,
    num_classes=10,
    features_only=True,
    cache_dir=r"E:\Git\pytorch-image-models\models",
)
model

In [2]:
# Two-class ConvNeXt V2 (tiny) initialised from the pretrained checkpoint.
model = timm.create_model(
    model_name="convnextv2_tiny.fcmae_ft_in1k",
    pretrained=True,
    num_classes=2,  # number of target classes
    # features_only=True,  # optional switch, left disabled here
    cache_dir=r"E:\Git\pytorch-image-models\models",  # weight cache location (optional)
)

In [None]:
# Freeze everything except the classifier: any parameter whose qualified name
# does not contain 'head' stops receiving gradients.
for param_name, tensor in model.named_parameters():
    print(param_name, tensor.shape)
    if 'head' in param_name:
        continue  # head parameters keep their current requires_grad setting
    tensor.requires_grad = False

In [10]:
# Apply the freezing policy to every parameter and tally how many scalar
# weights end up trainable versus frozen.
trainable_params = 0
frozen_params = 0
freeze_backbone = True
freeze_layers = None
head_keywords = ['head', 'classifier', 'pred']
for name, param in model.named_parameters():
    print(name, param.shape)

    # Explicitly listed layers are always frozen.
    in_freeze_list = freeze_layers is not None and any(layer in name for layer in freeze_layers)
    # Anything that does not look like a classifier head counts as backbone.
    is_head_param = any(head_key in name for head_key in head_keywords)

    # A parameter stays trainable unless it is listed explicitly, or the whole
    # backbone is frozen and the parameter is not part of the head.
    requires_grad = not (in_freeze_list or (freeze_backbone and not is_head_param))
    param.requires_grad = requires_grad

    # Accumulate the element counts.
    num_elements = param.numel()
    if requires_grad:
        trainable_params += num_elements
    else:
        frozen_params += num_elements

print(f"参数统计: 可训练参数 {trainable_params:,} | 冻结参数 {frozen_params:,} | 总参数 {trainable_params + frozen_params:,}")

stem.0.weight torch.Size([96, 3, 4, 4])
stem.0.bias torch.Size([96])
stem.1.weight torch.Size([96])
stem.1.bias torch.Size([96])
stages.0.blocks.0.conv_dw.weight torch.Size([96, 1, 7, 7])
stages.0.blocks.0.conv_dw.bias torch.Size([96])
stages.0.blocks.0.norm.weight torch.Size([96])
stages.0.blocks.0.norm.bias torch.Size([96])
stages.0.blocks.0.mlp.fc1.weight torch.Size([384, 96])
stages.0.blocks.0.mlp.fc1.bias torch.Size([384])
stages.0.blocks.0.mlp.grn.weight torch.Size([384])
stages.0.blocks.0.mlp.grn.bias torch.Size([384])
stages.0.blocks.0.mlp.fc2.weight torch.Size([96, 384])
stages.0.blocks.0.mlp.fc2.bias torch.Size([96])
stages.0.blocks.1.conv_dw.weight torch.Size([96, 1, 7, 7])
stages.0.blocks.1.conv_dw.bias torch.Size([96])
stages.0.blocks.1.norm.weight torch.Size([96])
stages.0.blocks.1.norm.bias torch.Size([96])
stages.0.blocks.1.mlp.fc1.weight torch.Size([384, 96])
stages.0.blocks.1.mlp.fc1.bias torch.Size([384])
stages.0.blocks.1.mlp.grn.weight torch.Size([384])
stages.0.bl

In [9]:
# Display the model's `head` submodule; the notebook renders its repr as the
# cell output.
model.head

NormMlpClassifierHead(
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
  (norm): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (pre_logits): Identity()
  (drop): Dropout(p=0.0, inplace=False)
  (fc): Linear(in_features=768, out_features=2, bias=True)
)

In [12]:
# Count trainable versus frozen scalar weights inside the classifier head only.
trainable_params = 0
frozen_params = 0
for name, param in model.head.named_parameters():
    print(name, param.shape)
    # Accumulate the element counts.
    num_elements = param.numel()
    if not param.requires_grad:
        frozen_params += num_elements
    else:
        trainable_params += num_elements

print(f"参数统计: 可训练参数 {trainable_params:,} | 冻结参数 {frozen_params:,} | 总参数 {trainable_params + frozen_params:,}")

norm.weight torch.Size([768])
norm.bias torch.Size([768])
fc.weight torch.Size([2, 768])
fc.bias torch.Size([2])
参数统计: 可训练参数 3,074 | 冻结参数 0 | 总参数 3,074


### data

In [None]:
# Preview a RandAugment policy by applying it repeatedly to one training image.
from PIL import Image
import matplotlib.pyplot as plt
from timm.data.auto_augment import rand_augment_transform

# Alternative kept for reference: have timm assemble a full transform pipeline.
# from timm.data.transforms_factory import create_transform
# print(create_transform(224,))
# print(create_transform(224, is_training=True))
# create_transform(224, is_training=True, auto_augment='rand-m9-mstd0.5')
tfm = rand_augment_transform(config_str='rand-m9-mstd0.5', hparams={'img_mean': (124, 116, 104)})
print(tfm)

img = Image.open(r"E:\Data\TrainSet\13_HS_CaF2_cls\1029_a1b8\images\train\1\3_44.bmp")

# Alternative kept for reference: preview a random resized crop instead.
# from timm.data.transforms import RandomResizedCropAndInterpolation
#
# tfm = RandomResizedCropAndInterpolation(size=224, interpolation='random')

fig, ax = plt.subplots(2, 4, figsize=(10, 5))

# Fill the 2x4 grid row by row; every panel gets its own freshly augmented copy.
for panel in ax.flat:
    panel.imshow(tfm(img))

fig.tight_layout()
plt.show()

### dataset

In [None]:
from timm.data import ImageDataset, create_transform
from torch.utils.data import DataLoader

def create_dataloader_iterator(root=r'E:\Data\TrainSet\01_Ore_seg\images', img_size=224, batch_size=4):
    """Build an image dataset under ``root`` and return an iterator over its batches.

    Args:
        root: Directory handed to ``ImageDataset``. Defaults to the dataset
            location this notebook was written against.
        img_size: Size forwarded to ``create_transform`` to build the
            preprocessing pipeline.
        batch_size: Number of samples per batch produced by the ``DataLoader``.

    Returns:
        An iterator over the ``DataLoader``; call ``next()`` on it to pull one
        batch at a time. The iterator is single-pass: once exhausted, call this
        function again to get a fresh one.
    """
    dataset = ImageDataset(root, transform=create_transform(img_size))
    return iter(DataLoader(dataset, batch_size=batch_size))

In [None]:
# Despite its name, `dataloader` is the iterator returned by the helper, not a
# DataLoader object, so batches are pulled from it with next().
dataloader = create_dataloader_iterator()

In [None]:
# Pull the next batch from the iterator and unpack it into two parts.
# Presumably `inputs` are the image tensors and `classes` their labels - verify
# against the dataset's output.
inputs, classes = next(dataloader)

In [None]:
import torchvision
import numpy as np

In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo channel-wise normalisation on an image tensor and draw it with matplotlib.

    Args:
        inp: Channels-first image tensor (channels, height, width). It may
            require grad or live on an accelerator; it is detached and copied
            to the CPU before conversion to NumPy.
        title: Optional title for the plot; nothing is set when ``None``.
        mean: Per-channel mean that was subtracted during normalisation.
            Defaults to the commonly used ImageNet statistics.
        std: Per-channel standard deviation that the image was divided by.
            Defaults to the commonly used ImageNet statistics.
    """
    # Tensor.numpy() raises for tensors that track gradients or are not on the
    # CPU, so normalise both cases first.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Invert (x - mean) / std, then clamp to the [0, 1] range that plt.imshow
    # expects for floating-point images.
    img = np.clip(np.asarray(std) * img + np.asarray(mean), 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
In [None]:
# Tile the batch into a single grid image and title it with the per-sample
# class values converted to plain Python numbers.
out = torchvision.utils.make_grid(inputs)
batch_labels = [label.item() for label in classes]
imshow(out, title=batch_labels)