In [1]:
from torchvision.datasets import ImageFolder
import numpy as np
import torch


# 设置随机数生成器种子
seed = 24
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# 数据集根目录
data_root = './data/dataset_new/tongue_color'
dataset = ImageFolder(root=data_root)

class_names = dataset.classes

print(len(dataset))
print(dataset.class_to_idx)
print(dataset[0])

93
{'an_hong': 0, 'dan_bai': 1, 'dan_hong': 2, 'jiang_she': 3, 'qing_zi': 4}
(<PIL.Image.Image image mode=RGB size=1240x1411 at 0x273CD5B3FD0>, 0)


In [2]:
from torchvision.transforms import transforms
from torch.utils.data import DataLoader


dataset_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

data_loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
data_loader.dataset.transform = dataset_transform
# 获取所有图像数据
all_images = None

for inputs, _ in data_loader:
    all_images = inputs

# 计算数据集的均值和标准差
mean = torch.mean(all_images, dim=(0, 2, 3))
std = torch.std(all_images, dim=(0, 2, 3))

print(f"均值：{mean}")
print(f"标准差：{std}")

均值：tensor([0.4792, 0.4043, 0.4172])
标准差：tensor([0.2912, 0.2606, 0.2681])


In [3]:
from torch.utils.data import random_split
from imblearn.over_sampling import RandomOverSampler
from torch.utils.data import TensorDataset


class Cutout(object):
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)
        mask = np.ones((h, w), np.float32)

        for _ in range(self.n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - self.length // 2, 0, h)
            y2 = np.clip(y + self.length // 2, 0, h)
            x1 = np.clip(x - self.length // 2, 0, w)
            x2 = np.clip(x + self.length // 2, 0, w)

            mask[y1:y2, x1:x2] = 0

        mask = torch.from_numpy(mask)
        img *= mask.unsqueeze(0)

        return img


# 数据增强和预处理
data_augmentation_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.RandomApply(transforms.GaussianBlur(3, sigma=(0.1, 2.0)), 0.5),
    Cutout(n_holes=8, length=32),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


# # 计算每个类别的样本数量
# class_counts = torch.bincount(torch.tensor(dataset.targets))

# # 计算每个类别的权重
# class_weights = 1 / class_counts.float()


# 划分数据集
train_size = int(4 / 6.0 * len(dataset))
val_size = int(1 / 6.0 * len(dataset))
test_size = len(dataset) - train_size - val_size
remaining_size = len(dataset) - train_size
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size])

train_dataset.dataset.transform = data_augmentation_transform
val_dataset.dataset.transform = transform
test_dataset.dataset.transform = transform

# 获取训练集的特征和标签
X_train = []
y_train = []
for inputs, labels in train_dataset:
    X_train.append(inputs.numpy())  # 将张量转换为 numpy 数组
    y_train.append(labels)
X_train = np.array(X_train)
y_train = np.array(y_train)

# 创建 RandomOverSampler 实例
ros = RandomOverSampler(random_state=24)

# 进行过采样
X_resampled, y_resampled = ros.fit_resample(
    X_train.reshape(X_train.shape[0], -1), y_train)

# 将特征转换回张量形式
X_resampled = torch.tensor(X_resampled.reshape(-1, *inputs.shape))

# 创建过采样后的训练集
oversampled_train_dataset = TensorDataset(
    X_resampled, torch.tensor(y_resampled))

# 创建 DataLoader
train_loader = DataLoader(oversampled_train_dataset,
                          batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8)
test_loader = DataLoader(test_dataset, batch_size=8)

In [4]:
# 输出数据集大小
print("过采样后的训练集大小:", len(oversampled_train_dataset))
print("验证集大小:", len(val_dataset))
print("测试集大小:", len(test_dataset))

for inputs, labels in train_loader:
    print(inputs.shape, labels.shape)
    print(inputs[0])
    break

for inputs, labels in val_loader:
    print(inputs.shape, labels.shape)
    print(inputs[0])
    break

id2label = {id: label for id, label in enumerate(
    train_dataset.dataset.classes)}
label2id = {label: id for id, label in id2label.items()}

过采样后的训练集大小: 175
验证集大小: 15
测试集大小: 16
torch.Size([8, 3, 224, 224]) torch.Size([8])
tensor([[[-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455],
         [-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455],
         [-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455],
         ...,
         [-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455],
         [-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455],
         [-1.6455, -1.6455, -1.6455,  ..., -1.6455, -1.6455, -1.6455]],

        [[-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513],
         [-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513],
         [-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513],
         ...,
         [-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513],
         [-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513],
         [-1.5513, -1.5513, -1.5513,  ..., -1.5513, -1.5513, -1.5513]],

        [[-1.5559, -1.5559, -1.5559

In [5]:
from transformers import AutoFeatureExtractor


# 加载 Vision Transformer 的特征提取器
model_name = 'microsoft/swin-base-patch4-window7-224'
feature_extractor = AutoFeatureExtractor.from_pretrained(
    model_name,
    do_normalize=False,
    do_resize=False,
    do_rescale=False,
)
# feature_extractor.image_mean = mean
# feature_extractor.image_std = std
feature_extractor

preprocessor_config.json:   0%|          | 0.00/255 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


ViTFeatureExtractor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": false,
  "do_rescale": false,
  "do_resize": false,
  "image_mean": [
    0.485,
    0.456,
    0.406
  ],
  "image_processor_type": "ViTFeatureExtractor",
  "image_std": [
    0.229,
    0.224,
    0.225
  ],
  "resample": 3,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [7]:
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


def custom_collate(batch):

    # 将样本数据列表拆分为输入和标签
    images, labels = list(zip(*batch))
    inputs = feature_extractor(images, return_tensors='pt')

    inputs['labels'] = torch.tensor(labels)
    inputs['pixel_values'] = torch.stack(
        [input for input in inputs["pixel_values"]])
    return inputs


train_data = MyDataset(train_dataset)
val_data = MyDataset(val_dataset)
test_data = MyDataset(test_dataset)

In [12]:
from transformers import SwinForImageClassification

model = SwinForImageClassification.from_pretrained(model_name, num_labels=len(dataset.classes), id2label=id2label, label2id=label2id, ignore_mismatched_sizes = True)

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-base-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 1024]) in the checkpoint and torch.Size([5, 1024]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
from transformers import TrainingArguments, Trainer

metric_name = "accuracy"

args = TrainingArguments(
    f"tongue-color-classification",
    remove_unused_columns=False,
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=30,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    warmup_ratio=0.1,
    logging_steps=10,
    seed=seed,
)

In [21]:
from sklearn.metrics import accuracy_score


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return dict(accuracy=accuracy_score(predictions, labels))
    # logits, labels = eval_pred
    # predictions = logits.argsort(axis=-1)[:, -5:]  # 获取top-5预测结果
    # # 将标签扩展为与预测结果相同的形状
    # labels_expanded = labels.reshape(-1, 1).repeat(5, axis=1)
    # correct_predictions = (predictions == labels_expanded)  # 判断预测结果是否包含正确标签
    # top5_accuracy = correct_predictions.any(axis=-1).mean()  # 计算top-5准确率
    # return {"accuracy": top5_accuracy}

In [22]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=custom_collate,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [24]:
# Train and save results
train_results = trainer.train()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

# Evaluate on validation set
metrics = trainer.evaluate(val_data)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.984650135040283, 'eval_accuracy': 0.6, 'eval_runtime': 0.3825, 'eval_samples_per_second': 39.218, 'eval_steps_per_second': 5.229, 'epoch': 1.0}
{'loss': 0.0001, 'grad_norm': 0.002496622037142515, 'learning_rate': 2.0833333333333336e-05, 'epoch': 1.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.169843912124634, 'eval_accuracy': 0.6, 'eval_runtime': 0.4034, 'eval_samples_per_second': 37.182, 'eval_steps_per_second': 4.958, 'epoch': 2.0}
{'loss': 0.0, 'grad_norm': 0.002445352729409933, 'learning_rate': 4.166666666666667e-05, 'epoch': 2.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.460874319076538, 'eval_accuracy': 0.6, 'eval_runtime': 0.3939, 'eval_samples_per_second': 38.082, 'eval_steps_per_second': 5.078, 'epoch': 3.0}
{'loss': 0.0, 'grad_norm': 0.00035053511965088546, 'learning_rate': 4.8611111111111115e-05, 'epoch': 3.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.642017126083374, 'eval_accuracy': 0.6, 'eval_runtime': 0.3987, 'eval_samples_per_second': 37.626, 'eval_steps_per_second': 5.017, 'epoch': 4.0}
{'loss': 0.0, 'grad_norm': 0.00230101915076375, 'learning_rate': 4.62962962962963e-05, 'epoch': 5.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.7000434398651123, 'eval_accuracy': 0.6, 'eval_runtime': 0.397, 'eval_samples_per_second': 37.78, 'eval_steps_per_second': 5.037, 'epoch': 5.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.855368137359619, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3964, 'eval_samples_per_second': 37.838, 'eval_steps_per_second': 5.045, 'epoch': 6.0}
{'loss': 0.0001, 'grad_norm': 0.38988572359085083, 'learning_rate': 4.3981481481481486e-05, 'epoch': 6.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.713486433029175, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.4009, 'eval_samples_per_second': 37.417, 'eval_steps_per_second': 4.989, 'epoch': 7.0}
{'loss': 0.0, 'grad_norm': 0.008334229700267315, 'learning_rate': 4.166666666666667e-05, 'epoch': 7.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.6170756816864014, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3973, 'eval_samples_per_second': 37.751, 'eval_steps_per_second': 5.033, 'epoch': 8.0}
{'loss': 0.3204, 'grad_norm': 37.25185012817383, 'learning_rate': 3.935185185185186e-05, 'epoch': 8.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.732788562774658, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3958, 'eval_samples_per_second': 37.895, 'eval_steps_per_second': 5.053, 'epoch': 9.0}
{'loss': 0.0056, 'grad_norm': 0.016994157806038857, 'learning_rate': 3.7037037037037037e-05, 'epoch': 10.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.5037777423858643, 'eval_accuracy': 0.6, 'eval_runtime': 0.3938, 'eval_samples_per_second': 38.088, 'eval_steps_per_second': 5.078, 'epoch': 10.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.6328327655792236, 'eval_accuracy': 0.6, 'eval_runtime': 0.3964, 'eval_samples_per_second': 37.84, 'eval_steps_per_second': 5.045, 'epoch': 11.0}
{'loss': 0.0003, 'grad_norm': 0.021914666518568993, 'learning_rate': 3.472222222222222e-05, 'epoch': 11.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 2.9033052921295166, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 0.3953, 'eval_samples_per_second': 37.942, 'eval_steps_per_second': 5.059, 'epoch': 12.0}
{'loss': 0.0586, 'grad_norm': 0.027946218848228455, 'learning_rate': 3.240740740740741e-05, 'epoch': 12.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.436224937438965, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3956, 'eval_samples_per_second': 37.919, 'eval_steps_per_second': 5.056, 'epoch': 13.0}
{'loss': 0.0, 'grad_norm': 0.00024005823070183396, 'learning_rate': 3.0092592592592593e-05, 'epoch': 13.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.7063252925872803, 'eval_accuracy': 0.6, 'eval_runtime': 0.3944, 'eval_samples_per_second': 38.029, 'eval_steps_per_second': 5.07, 'epoch': 14.0}
{'loss': 0.0001, 'grad_norm': 0.00014966308663133532, 'learning_rate': 2.777777777777778e-05, 'epoch': 15.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.566612958908081, 'eval_accuracy': 0.6, 'eval_runtime': 0.3926, 'eval_samples_per_second': 38.211, 'eval_steps_per_second': 5.095, 'epoch': 15.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.1291308403015137, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3936, 'eval_samples_per_second': 38.109, 'eval_steps_per_second': 5.081, 'epoch': 16.0}
{'loss': 0.0002, 'grad_norm': 0.9121088981628418, 'learning_rate': 2.5462962962962965e-05, 'epoch': 16.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.932713985443115, 'eval_accuracy': 0.6, 'eval_runtime': 0.3987, 'eval_samples_per_second': 37.619, 'eval_steps_per_second': 5.016, 'epoch': 17.0}
{'loss': 0.0001, 'grad_norm': 0.001367276650853455, 'learning_rate': 2.314814814814815e-05, 'epoch': 17.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 5.678503036499023, 'eval_accuracy': 0.6, 'eval_runtime': 0.3948, 'eval_samples_per_second': 37.991, 'eval_steps_per_second': 5.065, 'epoch': 18.0}
{'loss': 0.0002, 'grad_norm': 0.0028825136832892895, 'learning_rate': 2.0833333333333336e-05, 'epoch': 18.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 6.061806678771973, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3986, 'eval_samples_per_second': 37.63, 'eval_steps_per_second': 5.017, 'epoch': 19.0}
{'loss': 0.0003, 'grad_norm': 0.0509231761097908, 'learning_rate': 1.8518518518518518e-05, 'epoch': 20.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 10.153315544128418, 'eval_accuracy': 0.3333333333333333, 'eval_runtime': 0.4004, 'eval_samples_per_second': 37.46, 'eval_steps_per_second': 4.995, 'epoch': 20.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.41769552230835, 'eval_accuracy': 0.4666666666666667, 'eval_runtime': 0.3968, 'eval_samples_per_second': 37.805, 'eval_steps_per_second': 5.041, 'epoch': 21.0}
{'loss': 0.2123, 'grad_norm': 0.003919141832739115, 'learning_rate': 1.6203703703703704e-05, 'epoch': 21.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.060247421264648, 'eval_accuracy': 0.6, 'eval_runtime': 0.4074, 'eval_samples_per_second': 36.818, 'eval_steps_per_second': 4.909, 'epoch': 22.0}
{'loss': 0.0, 'grad_norm': 7.734729297226295e-05, 'learning_rate': 1.388888888888889e-05, 'epoch': 22.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.025112628936768, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3978, 'eval_samples_per_second': 37.706, 'eval_steps_per_second': 5.027, 'epoch': 23.0}
{'loss': 0.0, 'grad_norm': 8.978115511126816e-05, 'learning_rate': 1.1574074074074075e-05, 'epoch': 23.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 4.005408763885498, 'eval_accuracy': 0.6, 'eval_runtime': 0.3953, 'eval_samples_per_second': 37.948, 'eval_steps_per_second': 5.06, 'epoch': 24.0}
{'loss': 0.0, 'grad_norm': 1.4836569789622445e-05, 'learning_rate': 9.259259259259259e-06, 'epoch': 25.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.9733364582061768, 'eval_accuracy': 0.6, 'eval_runtime': 0.395, 'eval_samples_per_second': 37.976, 'eval_steps_per_second': 5.064, 'epoch': 25.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.961927652359009, 'eval_accuracy': 0.6, 'eval_runtime': 0.392, 'eval_samples_per_second': 38.262, 'eval_steps_per_second': 5.102, 'epoch': 26.0}
{'loss': 0.0, 'grad_norm': 0.00021127311629243195, 'learning_rate': 6.944444444444445e-06, 'epoch': 26.25}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.7537031173706055, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3963, 'eval_samples_per_second': 37.855, 'eval_steps_per_second': 5.047, 'epoch': 27.0}
{'loss': 0.0003, 'grad_norm': 0.00020007290004286915, 'learning_rate': 4.6296296296296296e-06, 'epoch': 27.5}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.662189245223999, 'eval_accuracy': 0.6, 'eval_runtime': 0.3981, 'eval_samples_per_second': 37.674, 'eval_steps_per_second': 5.023, 'epoch': 28.0}
{'loss': 0.0, 'grad_norm': 0.00011119836563011631, 'learning_rate': 2.3148148148148148e-06, 'epoch': 28.75}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.6475675106048584, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3942, 'eval_samples_per_second': 38.052, 'eval_steps_per_second': 5.074, 'epoch': 29.0}
{'loss': 0.0, 'grad_norm': 5.3665604355046526e-05, 'learning_rate': 0.0, 'epoch': 30.0}


  0%|          | 0/2 [00:00<?, ?it/s]

{'eval_loss': 3.645921230316162, 'eval_accuracy': 0.5333333333333333, 'eval_runtime': 0.3935, 'eval_samples_per_second': 38.123, 'eval_steps_per_second': 5.083, 'epoch': 30.0}
{'train_runtime': 109.9758, 'train_samples_per_second': 16.913, 'train_steps_per_second': 2.182, 'train_loss': 0.024950993443849257, 'epoch': 30.0}
***** train metrics *****
  epoch                    =       30.0
  train_loss               =      0.025
  train_runtime            = 0:01:49.97
  train_samples_per_second =     16.913
  train_steps_per_second   =      2.182


  0%|          | 0/2 [00:00<?, ?it/s]

***** eval metrics *****
  epoch                   =       30.0
  eval_accuracy           =     0.6667
  eval_loss               =     3.8554
  eval_runtime            = 0:00:00.40
  eval_samples_per_second =     37.261
  eval_steps_per_second   =      4.968


In [None]:
# from timm.models import create_model

# Swin = create_model('swin_large_patch4_window7_224_in22k',pretrained=True)

  model = create_fn(


model.safetensors:   0%|          | 0.00/916M [00:00<?, ?B/s]

KeyboardInterrupt: 

In [25]:
import timm

model_names = timm.list_models(pretrained=True)
print(model_names)

['bat_resnext26ts.ch_in1k', 'beit_base_patch16_224.in22k_ft_in22k', 'beit_base_patch16_224.in22k_ft_in22k_in1k', 'beit_base_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_224.in22k_ft_in22k', 'beit_large_patch16_224.in22k_ft_in22k_in1k', 'beit_large_patch16_384.in22k_ft_in22k_in1k', 'beit_large_patch16_512.in22k_ft_in22k_in1k', 'beitv2_base_patch16_224.in1k_ft_in1k', 'beitv2_base_patch16_224.in1k_ft_in22k', 'beitv2_base_patch16_224.in1k_ft_in22k_in1k', 'beitv2_large_patch16_224.in1k_ft_in1k', 'beitv2_large_patch16_224.in1k_ft_in22k', 'beitv2_large_patch16_224.in1k_ft_in22k_in1k', 'botnet26t_256.c1_in1k', 'caformer_b36.sail_in1k', 'caformer_b36.sail_in1k_384', 'caformer_b36.sail_in22k', 'caformer_b36.sail_in22k_ft_in1k', 'caformer_b36.sail_in22k_ft_in1k_384', 'caformer_m36.sail_in1k', 'caformer_m36.sail_in1k_384', 'caformer_m36.sail_in22k', 'caformer_m36.sail_in22k_ft_in1k', 'caformer_m36.sail_in22k_ft_in1k_384', 'caformer_s18.sail_in1k', 'caformer_s18.sail_in1k_384', 'caformer_s