In [19]:
# Load the data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch

ROOT_DIR = "dataset//"
BATCH_SIZE = 64

# Per-channel statistics handed to Normalize. These are the usual ImageNet
# values; they are defined once so the training and validation pipelines
# cannot drift apart.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: random geometric/colour augmentation on the PIL image,
# then tensor conversion + normalisation. RandomErasing operates on tensors,
# so it has to come after ToTensor/Normalize.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1)),
    transforms.ColorJitter(brightness=0.2,
                           contrast=0.2,
                           saturation=0.2,
                           hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
    transforms.RandomErasing(p=0.25),
])

# Validation pipeline: deterministic resize only, same normalisation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# ImageFolder derives the class labels from the sub-directory names.
train_ds = datasets.ImageFolder(f"{ROOT_DIR}train", transform=train_transforms)
val_ds = datasets.ImageFolder(f"{ROOT_DIR}val", transform=val_transforms)
print("训练集长度：", len(train_ds))
print("测试集长度：", len(val_ds))

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
# Validation data is not shuffled: the metrics do not depend on sample order,
# and a fixed order keeps evaluation runs reproducible and comparable.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

训练集长度： 2150
测试集长度： 493


In [2]:
# Take a look at the dataset
from torch.utils.tensorboard import SummaryWriter

print(f"Number of classes detected: {len(train_ds.classes)}, details:")
print(train_ds.class_to_idx)

N_PREVIEW_SAMPLES = 10  # single images logged per dataset
N_PREVIEW_BATCHES = 9   # batches logged per dataloader


def log_samples(writer, tag, dataset, max_samples):
    """Log the first ``max_samples`` images of ``dataset`` under ``tag``.

    The count is capped at ``len(dataset)`` so a small split does not raise
    ``IndexError``.
    """
    for idx in range(min(max_samples, len(dataset))):
        img, _ = dataset[idx]
        writer.add_image(tag, img, idx)


def log_batches(writer, tag, loader, max_batches):
    """Log up to ``max_batches`` batches from ``loader`` under ``tag``.

    The loop stops right after the last wanted batch has been written, so no
    extra batch is fetched from the loader.
    """
    if max_batches <= 0:
        return
    for step, (imgs, _) in enumerate(loader):
        writer.add_images(tag, imgs, step)
        if step + 1 >= max_batches:
            break


# The context manager closes (and flushes) the writer even if one of the
# logging calls raises.
# NOTE(review): the tensors are logged exactly as the transforms produce them
# (i.e. after Normalize), so colours shown in TensorBoard will look off —
# de-normalise here if faithful previews are needed.
with SummaryWriter("logs") as writer:
    log_samples(writer, "train_ds", train_ds, N_PREVIEW_SAMPLES)
    log_samples(writer, "val_ds", val_ds, N_PREVIEW_SAMPLES)
    log_batches(writer, "train_dl", train_dl, N_PREVIEW_BATCHES)
    log_batches(writer, "val_dl", val_dl, N_PREVIEW_BATCHES)

Number of classes detected: 45, details:
{'Aedes_aegypti': 0, 'Aedes_alboannulatus': 1, 'Aedes_albopictus': 2, 'Aedes_canadensis': 3, 'Aedes_caspius': 4, 'Aedes_geniculatus': 5, 'Aedes_japonicus': 6, 'Aedes_notoscriptus': 7, 'Aedes_sollicitans': 8, 'Aedes_taeniorhynchus': 9, 'Aedes_triseriatus': 10, 'Aedes_trivittatus': 11, 'Aedes_vexans': 12, 'Anopheles_punctipennis': 13, 'Anopheles_quadrimaculatus': 14, 'Aptera_fusca': 15, 'Armigeres_subalbatus': 16, 'Blatta_orientalis': 17, 'Coquillettidia_perturbans': 18, 'Culex_pipiens': 19, 'Culex_quinquefasciatus': 20, 'Culex_tarsalis': 21, 'Culiseta_annulata': 22, 'Culiseta_incidens': 23, 'Culiseta_inornata': 24, 'Culiseta_longiareolata': 25, 'Mesembrina_meridiana': 26, 'Mus_musculus': 27, 'Musca_domestica': 28, 'Panchlora_nivea': 29, 'Periplaneta_americana': 30, 'Periplaneta_australasiae': 31, 'Periplaneta_fuliginosa': 32, 'Pseudomops_septentrionalis': 33, 'Psorophora_ciliata': 34, 'Psorophora_columbiae': 35, 'Psorophora_cyanescens': 36, 'Psor

In [None]:
# Load the model
import torch, torchvision

MODEL_PATH = "convnext_tiny-983f1562.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the bare architecture; `weights=None` is the non-deprecated way to ask
# for no bundled weights (a boolean is only accepted through the legacy path).
model = torchvision.models.convnext_tiny(weights=None)

# Load the local checkpoint. `map_location="cpu"` makes this work on machines
# without CUDA regardless of where the tensors were saved; the model is moved
# to DEVICE afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted
# source (or pass `weights_only=True` on PyTorch versions that support it).
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)

# Swap the final classification layer for one sized to this dataset.
# Merely registering an extra layer with `add_module` is not enough: the
# torchvision ConvNeXt forward pass only runs features -> avgpool -> classifier,
# so an added module would never be called and the network would keep emitting
# the checkpoint's original number of logits.
num_classes = len(train_ds.classes)
old_head = model.classifier[-1]
model.classifier[-1] = torch.nn.Linear(old_head.in_features, num_classes)
model.to(DEVICE)

print(f"Running on: {DEVICE}")
print(model)



ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=

In [None]:
import time, torch
from tqdm import tqdm

LR = 1e-3
EPOCH = 10

loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

for epoch in range(EPOCH):
    t0 = time.time()

    # ---------- train ----------
    # Running sums are weighted by batch size so the epoch averages stay
    # correct even when the last batch is smaller than the others.
    train_loss, train_correct, train_total = 0.0, 0, 0
    model.train()
    for imgs, targets in tqdm(train_dl, desc=f"Train {epoch+1}/{EPOCH}", leave=False):
        imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        # Reuse the loss that was already computed for the backward pass
        # instead of evaluating the criterion a second time.
        train_loss += loss.item() * imgs.size(0)
        train_correct += (outputs.argmax(1) == targets).sum().item()
        train_total += targets.size(0)

    # ---------- validate ----------
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # no gradients needed for evaluation
        for imgs, targets in tqdm(val_dl, desc="Val", leave=False):
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            outputs = model(imgs)

            val_loss += loss_fn(outputs, targets).item() * imgs.size(0)
            correct += (outputs.argmax(1) == targets).sum().item()
            total += targets.size(0)

    # Per-epoch summary: training metrics on the first line, validation
    # metrics on the second, wall-clock time for the whole epoch on the third.
    print(f"Epoch {epoch+1}: \n"
          f"train_loss = {train_loss/train_total:.4f}, "
          f"train_acc = {train_correct/train_total:.4%}, \n"
          f"val_loss = {val_loss/total:.4f}, "
          f"val_acc = {correct/total:.4%}, \n"
          f"time = {time.time()-t0:.1f}s")

  0%|          | 0/10 [00:00<?, ?it/s]

第1次训练，Loss: 7.840580940246582
第2次训练，Loss: 7.082515716552734
第3次训练，Loss: 6.528244495391846
第4次训练，Loss: 6.728500843048096
第5次训练，Loss: 5.981889724731445
第6次训练，Loss: 5.965354919433594
第7次训练，Loss: 5.704470157623291
第8次训练，Loss: 5.78027868270874
第9次训练，Loss: 5.678564071655273


  0%|          | 0/10 [01:02<?, ?it/s]


KeyboardInterrupt: 