In [None]:
# Load the data
import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch

ROOT_DIR = "dataset//"
BATCH_SIZE = 64
SEED = 42

# Seed torch's global RNG so the random augmentations below and the
# training-set shuffle order are reproducible across kernel restarts
# (DataLoader is created with the default num_workers=0, so everything
# runs in this process and uses this RNG).
torch.manual_seed(SEED)

# ImageNet channel statistics, shared by both pipelines so they cannot drift apart.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random crop / colour jitter / flip / rotation augmentation,
# then tensor conversion and normalisation. RandomErasing must come after
# ToTensor because it operates on tensors, not PIL images.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1)),
    transforms.ColorJitter(brightness=0.2,
                           contrast=0.2,
                           saturation=0.2,
                           hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    transforms.RandomErasing(p=0.25),
])

# Validation pipeline: deterministic resize + normalisation only.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_ds = datasets.ImageFolder(os.path.join(ROOT_DIR, "train"), transform=train_transforms)
val_ds = datasets.ImageFolder(os.path.join(ROOT_DIR, "val"), transform=val_transforms)
print("训练集长度：", len(train_ds))
# This is the validation split, so label it as such (it was printed as the test set).
print("验证集长度：", len(val_ds))

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
# Validation order does not affect the metric; keeping it fixed makes runs and
# previews comparable.
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# Take a look at the dataset
from itertools import islice

from torch.utils.tensorboard import SummaryWriter

N_PREVIEW_IMAGES = 10   # single samples logged per split
N_PREVIEW_BATCHES = 10  # loader batches logged per split

# Must match the Normalize(...) parameters used in the data-loading cell.
# Shaped (3, 1, 1) so they broadcast over a CHW image or an NCHW batch.
_PREVIEW_MEAN = torch.tensor((0.485, 0.456, 0.406)).view(3, 1, 1)
_PREVIEW_STD = torch.tensor((0.229, 0.224, 0.225)).view(3, 1, 1)


def denormalize(img):
    """Undo the channel normalisation so TensorBoard receives values in [0, 1].

    Accepts a single (C, H, W) image or an (N, C, H, W) batch; the statistics
    broadcast over the leading dimension. Values are clamped because
    augmentation (e.g. ColorJitter) can push pixels slightly out of range.
    """
    return (img * _PREVIEW_STD + _PREVIEW_MEAN).clamp(0.0, 1.0)


print(f"Number of classes detected: {len(train_ds.classes)}, details:")
print(train_ds.class_to_idx)

# The context manager closes (and flushes) the writer even if the cell is
# interrupted, e.g. by a KeyboardInterrupt while batches are being loaded.
with SummaryWriter("logs") as writer:
    # min(...) avoids an IndexError on splits smaller than the preview size.
    for idx in range(min(N_PREVIEW_IMAGES, len(train_ds))):
        train_img, _ = train_ds[idx]
        writer.add_image("train_ds", denormalize(train_img), idx)
    for idx in range(min(N_PREVIEW_IMAGES, len(val_ds))):
        val_img, _ = val_ds[idx]
        writer.add_image("val_ds", denormalize(val_img), idx)

    for step, (imgs, _) in enumerate(islice(train_dl, N_PREVIEW_BATCHES)):
        writer.add_images("train_dl", denormalize(imgs), step)
    for step, (imgs, _) in enumerate(islice(val_dl, N_PREVIEW_BATCHES)):
        writer.add_images("val_dl", denormalize(imgs), step)

{'Aedes_aegypti': 0, 'Aedes_alboannulatus': 1, 'Aedes_albopictus': 2, 'Aedes_canadensis': 3, 'Aedes_caspius': 4, 'Aedes_geniculatus': 5, 'Aedes_japonicus': 6, 'Aedes_notoscriptus': 7, 'Aedes_sollicitans': 8, 'Aedes_taeniorhynchus': 9, 'Aedes_triseriatus': 10, 'Aedes_trivittatus': 11, 'Aedes_vexans': 12, 'Anopheles_punctipennis': 13, 'Anopheles_quadrimaculatus': 14, 'Aptera_fusca': 15, 'Armigeres_subalbatus': 16, 'Blatta_orientalis': 17, 'Coquillettidia_perturbans': 18, 'Culex_pipiens': 19, 'Culex_quinquefasciatus': 20, 'Culex_tarsalis': 21, 'Culiseta_annulata': 22, 'Culiseta_incidens': 23, 'Culiseta_inornata': 24, 'Culiseta_longiareolata': 25, 'Mesembrina_meridiana': 26, 'Mus_musculus': 27, 'Musca_domestica': 28, 'Panchlora_nivea': 29, 'Periplaneta_americana': 30, 'Periplaneta_australasiae': 31, 'Periplaneta_fuliginosa': 32, 'Pseudomops_septentrionalis': 33, 'Psorophora_ciliata': 34, 'Psorophora_columbiae': 35, 'Psorophora_cyanescens': 36, 'Psorophora_ferox': 37, 'Pycnoscelus_surinamen

KeyboardInterrupt: 

In [None]:
# Load the model
MODEL_PATH = ""  # TODO: set to the saved model file before running this cell

if not MODEL_PATH:
    raise ValueError(
        "MODEL_PATH is empty: set it to the path of a saved model before running this cell."
    )

# SECURITY: torch.load unpickles the file, which can execute arbitrary code.
# Only load checkpoints from a trusted source.
# NOTE(review): on PyTorch >= 2.6 torch.load defaults to weights_only=True,
# which refuses a fully pickled nn.Module; if that error appears and the file
# is trusted, pass weights_only=False.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; move the model to an accelerator explicitly afterwards if needed.
model = torch.load(MODEL_PATH, map_location="cpu")

# TODO: the classifier head presumably has to output len(train_ds.classes)
# logits. The earlier draft called `model.add_model(...)`, which does not exist
# (nn.Module provides `add_module`), and left Linear's in_features blank.

print(model)

In [None]:
# Train
from tqdm.auto import trange
import time

LR = 1e-3
EPOCH = 10

# Run on a GPU when one is available; the model must be moved before the
# optimizer is built so the optimizer references the on-device parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)


def train_one_epoch(net, loader, criterion, opt, dev):
    """Run one optimisation pass of `net` over `loader`.

    Puts the network in training mode, then for each batch does forward,
    loss, backward and an optimizer step, printing the per-step loss.

    Returns:
        The sample-weighted mean training loss as a float, or float("nan")
        if the loader yields no batches.
    """
    net.train()
    total_loss = 0.0
    n_samples = 0
    for train_step, (imgs, targets) in enumerate(loader, start=1):
        imgs, targets = imgs.to(dev), targets.to(dev)
        outputs = net(imgs)
        loss = criterion(outputs, targets)

        opt.zero_grad()
        loss.backward()
        opt.step()

        batch_loss = loss.item()
        total_loss += batch_loss * targets.size(0)
        n_samples += targets.size(0)
        print(f"第{train_step}次训练，Loss: {batch_loss}")
    return total_loss / n_samples if n_samples else float("nan")


def evaluate(net, loader, criterion, dev):
    """Compute the sample-weighted mean loss of `net` over `loader`.

    Switches to eval mode (so dropout / batch-norm use inference behaviour)
    and disables autograd.

    Returns:
        The mean loss as a float, or float("nan") for an empty loader
        instead of dividing by zero.
    """
    net.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for imgs, targets in loader:
            imgs, targets = imgs.to(dev), targets.to(dev)
            outputs = net(imgs)
            # CrossEntropyLoss() uses reduction="mean", i.e. a per-batch mean;
            # re-weight by batch size so a smaller final batch (drop_last=False)
            # does not skew the average.
            total_loss += criterion(outputs, targets).item() * targets.size(0)
            n_samples += targets.size(0)
    return total_loss / n_samples if n_samples else float("nan")


for epoch in trange(EPOCH):
    t0 = time.perf_counter()
    train_loss = train_one_epoch(model, train_dl, loss_fn, optimizer, device)
    val_loss = evaluate(model, val_dl, loss_fn, device)
    elapsed = time.perf_counter() - t0
    print(f"第{epoch + 1}轮训练, Train_loss = {train_loss}, Val_loss = {val_loss}, 用时 {elapsed:.1f}s")