In [5]:
import subprocess
import os

# Source /etc/network_turbo in a bash child process and copy every variable
# whose `env` line contains "proxy" into this kernel's environment
# (presumably so later downloads go through the platform proxy - confirm).
# bash is invoked directly with an argument list: wrapping an explicit
# `bash -c` in shell=True only adds a redundant outer /bin/sh.
result = subprocess.run(
    ["bash", "-c", "source /etc/network_turbo && env | grep proxy"],
    capture_output=True,
    text=True,
)
output = result.stdout
if result.stderr.strip():
    # Stay best-effort (no exception), but do not hide a missing or broken script.
    print(f"network_turbo: {result.stderr.strip()}")
for line in output.splitlines():
    # Lines look like NAME=value; split on the first '=' only, because the
    # value (a URL) may itself contain '='.
    var, sep, value = line.partition('=')
    if sep:
        os.environ[var] = value

In [1]:
import os  
import torch  
from torch import optim, nn  
from PIL import Image  
from torch.utils import data  
from torchvision import models  
from torchvision.transforms import transforms

In [30]:
# Preprocessing applied to every image before it reaches the network:
# resize to 256, centre-crop to 224x224, convert to a tensor, then normalise
# each of the three channels with the mean / std below.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
]
transform = transforms.Compose(_preprocess_steps)

class LoraDataset(data.Dataset):
    """Image-folder dataset whose sub-folder names are ImageNet category names.

    Expected layout: ``<data_path>/<category>/<image file>``, where
    ``<category>`` is an entry of the ``meta["categories"]`` list shipped with
    the VGG19 ``IMAGENET1K_V1`` weights.

    Attributes:
        files: image paths, ordered by folder name and then by file name.
        labels: category index for the entry of ``files`` at the same position.
    """

    # Length of the one-hot label vector returned by __getitem__.
    NUM_CLASSES = 1000

    def __init__(self, data_path="data"):
        """Index every image file found one level below ``data_path``.

        Hidden entries (such as ``.ipynb_checkpoints``) and stray non-directory
        entries in ``data_path`` are skipped instead of crashing the scan.

        Raises:
            ValueError: if a folder that contains files is not named after a
                known category.
        """
        categories = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"]
        self.files = []
        self.labels = []
        # Sorted so that sample order does not depend on the filesystem.
        for category in sorted(os.listdir(data_path)):
            category_dir = os.path.join(data_path, category)
            if category.startswith(".") or not os.path.isdir(category_dir):
                continue
            image_paths = [
                os.path.join(category_dir, name)
                for name in sorted(os.listdir(category_dir))
                if not name.startswith(".")
                and os.path.isfile(os.path.join(category_dir, name))
            ]
            if not image_paths:
                # An empty folder contributes nothing, so its name is not validated.
                continue
            # Resolve the category once per folder rather than once per file.
            try:
                class_index = categories.index(category)
            except ValueError:
                raise ValueError(
                    f"folder {category_dir!r} is not named after a known category"
                ) from None
            self.files.extend(image_paths)
            self.labels.extend([class_index] * len(image_paths))

    def __getitem__(self, item):
        """Return ``(transformed image, one-hot float64 label)`` for sample ``item``."""
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the RGB image.
        with Image.open(self.files[item]) as raw_image:
            image = raw_image.convert("RGB")
        label = torch.zeros(self.NUM_CLASSES, dtype=torch.float64)
        label[self.labels[item]] = 1.
        return transform(image), label

    def __len__(self):
        """Number of indexed image files."""
        return len(self.files)


class TestLoraDataset(LoraDataset):
    """Same folder layout and samples as ``LoraDataset``; reads ``test`` by default."""

    def __init__(self, data_path="test"):
        """Index every image file found one level below ``data_path``."""
        super().__init__(data_path)

In [21]:
class Lora(nn.Module):
    """Low-rank linear map from a flattened input to ``n`` outputs.

    Computes ``x @ A @ B`` with ``A`` of shape ``(m, rank)`` and ``B`` of shape
    ``(rank, n)``. ``B`` is initialised to zeros, so a freshly constructed
    module returns all zeros and leaves whatever it is added to unchanged.

    Args:
        m: number of elements in one flattened input sample.
        n: number of output values per sample.
        rank: inner dimension of the factorisation.
    """

    def __init__(self, m, n, rank=10):
        super().__init__()
        self.m = m
        self.A = nn.Parameter(torch.randn(m, rank))
        self.B = nn.Parameter(torch.zeros(rank, n))

    def forward(self, inputs):
        """Flatten ``inputs`` to ``(-1, m)`` and return a ``(-1, n)`` tensor."""
        # reshape (not view) so that non-contiguous inputs, e.g. the result of
        # a transpose or permute, are accepted as well.
        flat = inputs.reshape(-1, self.m)
        return torch.mm(torch.mm(flat, self.A), self.B)

In [22]:
# Load the pretrained base model, freeze it, and create the trainable adapter.
# `weights` is passed by keyword: torchvision declares it keyword-only and
# passing it positionally is deprecated.
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
for params in vgg19.parameters():
    params.requires_grad = False
vgg19.eval()
# The adapter sees the flattened 3x224x224 image and emits one value per class.
lora = Lora(224 * 224 * 3, 1000)

batch_size = 16
lr = 1e-4
# Training data
lora_loader = data.DataLoader(LoraDataset(), batch_size=batch_size, shuffle=True)
# Only the adapter's parameters are optimised
optimizer = optim.Adam(lora.parameters(), lr=lr)
# Loss
loss_fn = nn.CrossEntropyLoss()
# Number of passes over the training data
epochs = 10
# Training loop
for epoch in range(epochs):
    for image, label in lora_loader:
        # Forward pass: the base model is frozen, so its logits need no
        # autograd bookkeeping; gradients flow only through the adapter.
        with torch.no_grad():
            base_logits = vgg19(image)
        pred = base_logits + lora(image)
        loss = loss_fn(pred, label)
        # Backward pass
        loss.backward()
        # Parameter update, then clear gradients for the next batch
        optimizer.step()
        optimizer.zero_grad()
        print(f"loss: {loss.item()}")

loss: 0.0007559689621909153
loss: 0.0003224081709352807
loss: 0.00014359283219770683
loss: 6.724195399480475e-05
loss: 3.321674641180531e-05
loss: 1.7324294863859297e-05
loss: 9.417301801780317e-06
loss: 5.4040791042098135e-06
loss: 3.258360493418877e-06
loss: 2.0265475815980003e-06


In [33]:
# # 测试  
# for image, _ in lora_loader:  
#     pred = vgg19(image) + lora(image)  
#     print(pred)
#     idx = torch.argmax(pred, dim=1).item()  
#     category = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"][idx]  
#     print(category)
# torch.save(lora.state_dict(), 'lora.pth')

In [32]:
# Evaluation does not benefit from shuffling; a fixed order keeps the printed
# predictions reproducible between runs.
test_lora_loader = data.DataLoader(TestLoraDataset(), batch_size=batch_size, shuffle=False)
# Look the category names up once instead of once per predicted sample.
categories = models.VGG19_Weights.IMAGENET1K_V1.value.meta["categories"]

# Inference only: without no_grad an autograd graph would be built through the
# adapter's trainable parameters for every batch.
with torch.no_grad():
    for image, _ in test_lora_loader:
        pred = vgg19(image) + lora(image)

        # Index of the highest-scoring class for every sample in the batch
        idx = torch.argmax(pred, dim=1)
        # Print the category name predicted for each sample
        for class_index in idx.tolist():
            category = categories[class_index]
            print(category)

torch.save(lora.state_dict(), 'lora.pth')

goldfish
goldfish
goldfish


In [1]:
import os
import torch
from torchvision import models, transforms
from PIL import Image
import requests

# Load the unmodified pretrained model
vgg19 = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
vgg19.eval()  # evaluation mode

# Image preprocessing
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Download the human-readable label list. The timeout keeps the cell from
# hanging indefinitely, and raise_for_status turns an HTTP error into an
# exception instead of a confusing JSON decoding failure.
LABELS_URL = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"
response = requests.get(LABELS_URL, timeout=30)
response.raise_for_status()
labels = response.json()

# Run inference on every image in the folder and print the predicted label
test_folder = 'test/goldfish'
for filename in sorted(os.listdir(test_folder)):  # sorted for a stable output order
    # Case-insensitive so that e.g. "photo.PNG" is not skipped.
    if filename.lower().endswith(('.jpg', '.png', '.jpeg')):
        img_path = os.path.join(test_folder, filename)
        # Force three channels: PNG files may be RGBA, palette or greyscale,
        # which the 3-channel Normalize step above cannot handle.
        with Image.open(img_path) as img:
            image = preprocess(img.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            output = vgg19(image)
            predicted = output.argmax(dim=1)
            print(f"{filename}: {labels[predicted.item()]}")


2.png: goldfish
3.png: goldfish
1.png: goldfish
