In [6]:
# Load the data
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch

ROOT_DIR = "dataset//"
BATCH_SIZE = 64

train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1)),
    transforms.ColorJitter(brightness=0.2,
                           contrast=0.2,
                           saturation=0.2,
                           hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
    transforms.RandomErasing(p = 0.25),
])

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])

train_ds = datasets.ImageFolder(f"{ROOT_DIR}train_exp", transform = train_transforms)
val_ds = datasets.ImageFolder(f"{ROOT_DIR}val_exp", transform = val_transforms)
print("训练集长度：", len(train_ds))
print("测试集长度：", len(val_ds))

train_dl = DataLoader(train_ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = False)
val_dl = DataLoader(val_ds, batch_size = BATCH_SIZE, shuffle = True, drop_last = False)

训练集长度： 236
测试集长度： 49


In [7]:
# Take a look at the dataset
from torch.utils.tensorboard import SummaryWriter

print(f"Number of classes detected: {len(train_ds.classes)}, details:")
print(train_ds.class_to_idx)

"""
writer = SummaryWriter("logs")
for i in range(10):
    train_img, target = train_ds[i]
    writer.add_image("train_ds", train_img, i)
    val_img, target = val_ds[i]
    writer.add_image("val_ds", val_img, i)

step = 0
for data in train_dl:
    imgs, target = data
    writer.add_images("train_dl", imgs, step)
    step += 1
    if step == 9:
        break
step = 0
for data in val_dl:
    imgs, target = data
    writer.add_images("val_dl", imgs, step)
    step += 1
    if step == 9:
        break
writer.close()
"""

Number of classes detected: 5, details:
{'Aedes_aegypti': 0, 'Culex_pipiens': 1, 'Musca_domestica': 2, 'Periplaneta_americana': 3, 'Rattus_rattus': 4}


'\nwriter = SummaryWriter("logs")\nfor i in range(10):\n    train_img, target = train_ds[i]\n    writer.add_image("train_ds", train_img, i)\n    val_img, target = val_ds[i]\n    writer.add_image("val_ds", val_img, i)\n\nstep = 0\nfor data in train_dl:\n    imgs, target = data\n    writer.add_images("train_dl", imgs, step)\n    step += 1\n    if step == 9:\n        break\nstep = 0\nfor data in val_dl:\n    imgs, target = data\n    writer.add_images("val_dl", imgs, step)\n    step += 1\n    if step == 9:\n        break\nwriter.close()\n'

In [8]:
# Load the model
import torch, torchvision

MODEL_PATH = "convnext_tiny-983f1562.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torchvision.models.convnext_tiny(weights = False)
model.load_state_dict(torch.load(MODEL_PATH))

# model = torch.load(MODEL_PATH)

model.add_module("new_head", torch.nn.Linear(1000 ,len(train_ds.classes)))
model.to(DEVICE)

print(f"Running on: {DEVICE}")
print(model)

Running on: cuda
ConvNeXt(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
    )
    (1): Sequential(
      (0): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Linear(in_features=96, out_features=384, bias=True)
          (4): GELU(approximate='none')
          (5): Linear(in_features=384, out_features=96, bias=True)
          (6): Permute()
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): CNBlock(
        (block): Sequential(
          (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
          (3): Li

In [10]:
import time, torch
from tqdm import tqdm

LR = 1e-3
EPOCH = 10

loss_fn = torch.nn.CrossEntropyLoss().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

for epoch in range(EPOCH):
    t0 = time.time()
    # ---------- train ----------
    train_loss, train_correct, train_total = 0.0, 0, 0
    model.train()
    for imgs, targets in tqdm(train_dl, desc=f"Train {epoch+1}/{EPOCH}", leave=False):
        imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss    = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss_fn(outputs, targets).item() * imgs.size(0)
        train_correct  += (outputs.argmax(1) == targets).sum().item()
        train_total    += targets.size(0)

    # ---------- validate ----------
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for imgs, targets in tqdm(val_dl, desc="Val", leave=False):
            imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
            outputs = model(imgs)

            val_loss += loss_fn(outputs, targets).item() * imgs.size(0)
            correct  += (outputs.argmax(1) == targets).sum().item()
            total    += targets.size(0)

    print(f"Epoch {epoch+1}: \n"
          f"train_loss = {train_loss/train_total:.4f}, "
          f"train_acc = {train_correct/train_total:.4%}, \n"
          f"val_loss = {val_loss/total:.4f}, "
          f"val_acc = {correct/total:.4%}, \n"
          f"time = {time.time()-t0:.1f}s")

torch.save(model.state_dict(), "conv_tiny_torch.pth")

                                                         

Epoch 1: 
train_loss = 3.0248, train_acc = 18.6441%, 
val_loss = 2.8073, val_acc = 30.6122%, 
time = 73.6s


                                                         

Epoch 2: 
train_loss = 3.1170, train_acc = 24.1525%, 
val_loss = 2.2222, val_acc = 34.6939%, 
time = 73.9s


                                                         

Epoch 3: 
train_loss = 2.5796, train_acc = 21.1864%, 
val_loss = 1.8501, val_acc = 38.7755%, 
time = 71.4s


                                                         

Epoch 4: 
train_loss = 2.0175, train_acc = 29.2373%, 
val_loss = 1.6223, val_acc = 34.6939%, 
time = 71.1s


                                                         

Epoch 5: 
train_loss = 2.1640, train_acc = 25.8475%, 
val_loss = 2.0606, val_acc = 34.6939%, 
time = 70.8s


                                                         

Epoch 6: 
train_loss = 1.8185, train_acc = 36.8644%, 
val_loss = 1.3477, val_acc = 40.8163%, 
time = 70.7s


                                                         

Epoch 7: 
train_loss = 1.5395, train_acc = 37.7119%, 
val_loss = 1.2933, val_acc = 51.0204%, 
time = 70.8s


                                                         

Epoch 8: 
train_loss = 1.9364, train_acc = 32.2034%, 
val_loss = 1.2464, val_acc = 57.1429%, 
time = 70.7s


                                                         

Epoch 9: 
train_loss = 1.4435, train_acc = 44.0678%, 
val_loss = 1.2302, val_acc = 42.8571%, 
time = 70.6s


                                                          

Epoch 10: 
train_loss = 1.5757, train_acc = 33.4746%, 
val_loss = 1.4437, val_acc = 44.8980%, 
time = 70.6s


