In [29]:
# In[1]: 导入所需的库
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50, ResNet50_Weights
from transformers import GPT2Tokenizer, GPT2Model
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from torch.cuda.amp import GradScaler
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10, CIFAR100

In [30]:
# In[2]: 定义图像编码器
class ImageEncoder(nn.Module):
    def __init__(self, mode='finetune'):
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V1
        self.resnet = resnet50(weights=weights)
        
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-2])
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1))

    def forward(self, x):
        x = self.resnet(x)
        x = self.adaptive_pool(x)
        x = torch.reshape(x, (x.shape[0], x.shape[1]))
        return x

In [31]:
# In[3]: 定义文本编码器
class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = GPT2Tokenizer.from_pretrained('openai-community/gpt2')
        self.model = GPT2Model.from_pretrained('openai-community/gpt2')
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def forward(self, texts):
        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.model(**inputs)
        return outputs.last_hidden_state[:, -1, :]

In [32]:
# In[4]: 定义CLIP模型
class CLIP(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        
        # 投影层
        self.image_projection = nn.Linear(2048, embedding_dim)
        self.text_projection = nn.Linear(768, embedding_dim)
        
        # 额外的线性层
        self.image_layer1 = nn.Linear(embedding_dim, embedding_dim)
        self.text_layer1 = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, image, text):
        image_embedding = self.image_encoder(image)
        text_embedding = self.text_encoder(text)
        
        image_embedding = self.image_projection(image_embedding)
        text_embedding = self.text_projection(text_embedding)
        
        image_embedding = self.image_layer1(image_embedding)
        text_embedding = self.text_layer1(text_embedding)
        
        # 标准化embeddings
        image_embedding = F.normalize(image_embedding, p=2, dim=-1)
        text_embedding = F.normalize(text_embedding, p=2, dim=-1)
        
        return image_embedding, text_embedding

In [33]:
# In[5]: 设置设备和数据集
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 加载CIFAR10数据集
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# In[6]: 初始化模型
model = CLIP(embedding_dim=512).to(device)

Files already downloaded and verified
Files already downloaded and verified


In [34]:
# In[2]: 实现Triplet Loss
class TripletLoss(nn.Module):
    def __init__(self, margin=0.2):
        super().__init__()
        self.margin = margin
        
    def forward(self, img_emb, text_emb):
        # 计算正例对的相似度
        img_text_similarity = torch.matmul(img_emb, text_emb.t()).diag()
        
        # 创建负例对（通过循环移位）
        n = img_emb.shape[0]
        original_list = list(range(n))
        shifted_list = original_list[1:] + [original_list[0]]
        shuffled_image = img_emb[shifted_list]
        shuffled_text = text_emb[shifted_list]
        
        # 计算与负例的相似度
        neg_sim_img = torch.matmul(img_emb, shuffled_text.t()).diag()
        neg_sim_text = torch.matmul(text_emb, shuffled_image.t()).diag()
        
        # 计算triplet loss
        img_loss = torch.clamp(self.margin + neg_sim_img - img_text_similarity, min=0)
        text_loss = torch.clamp(self.margin + neg_sim_text - img_text_similarity, min=0)
        
        return (img_loss.mean() + text_loss.mean()) / 2

In [35]:

# In[3]: 实现Embedding Queue
class EmbeddingQueue:
    def __init__(self, max_size=15):
        self.max_size = max_size
        self.image_queue = []
        self.text_queue = []
    
    def add_queue(self, image_embedding, text_embedding):
        """添加新的embeddings到队列并维护最大大小"""
        self.image_queue.append(image_embedding)
        self.text_queue.append(text_embedding)
        
        # 如果队列超过最大大小，移除最旧的embeddings
        if len(self.image_queue) > self.max_size:
            self.image_queue.pop(0)
            self.text_queue.pop(0)
    
    def return_values(self):
        """返回存储的embeddings"""
        return self.image_queue, self.text_queue
    
    def clear(self):
        """清除所有存储的embeddings"""
        self.image_queue = []
        self.text_queue = []


In [36]:
def train_with_queue(model, dataloader, optimizer, scaler, epoch, embedding_queue=None):
    model.train()
    total_loss = 0
    
    pbar = tqdm(dataloader, desc=f'Epoch {epoch}')
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        texts = [f"This is an image of a {train_dataset.classes[label]}" for label in labels]
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            img_emb, text_emb = model(images, texts)
            
            if embedding_queue is not None:
                # 更新队列中的embedding
                with torch.no_grad():  # 添加这行，避免梯度计算
                    embedding_queue.add_queue(img_emb.clone().detach(), 
                                           text_emb.clone().detach())
                
                saved_img_embeddings, saved_text_embeddings = embedding_queue.return_values()
                
                if len(saved_img_embeddings) > 0:
                    # 合并当前batch和队列中的embeddings
                    queue_img = torch.cat(saved_img_embeddings, dim=0)
                    queue_text = torch.cat(saved_text_embeddings, dim=0)
                    
                    # 确保维度对齐
                    img_mat = torch.cat([img_emb, queue_img.to(device)], dim=0)
                    text_mat = torch.cat([text_emb, queue_text.to(device)], dim=0)
                else:
                    img_mat = img_emb
                    text_mat = text_emb
            
            # 计算损失时只使用当前batch的标签
            labels = torch.arange(img_emb.shape[0]).to(device)
            logits_img_text = torch.matmul(img_emb, text_mat.t()) * 2
            logits_text_img = torch.matmul(text_emb, img_mat.t()) * 2
            
            img_text_loss = F.cross_entropy(logits_img_text, labels)
            text_img_loss = F.cross_entropy(logits_text_img, labels)
            loss = (img_text_loss + text_img_loss) / 2
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
    return total_loss / len(dataloader)

In [37]:
# In[5]: 实现实验运行和结果记录函数
def run_experiments(model, train_dataset, test_dataset, config):
    results = {
        'original': {'loss': [], 'accuracy': []},
        'triplet': {'loss': [], 'accuracy': []},
        'queue': {'loss': [], 'accuracy': []}
    }
    
    for method in ['original', 'triplet', 'queue']:
        print(f"\nTraining with {method} method...")
        
        # 重置模型
        model.load_state_dict(torch.load('initial_weights.pth'))
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['learning_rate'])
        scaler = GradScaler()
        
        if method == 'queue':
            embedding_queue = EmbeddingQueue(max_size=config['queue_size'])
        else:
            embedding_queue = None
        
        for epoch in range(config['epochs']):
            # 训练
            if method == 'triplet':
                loss = train_with_triplet_loss(model, train_loader, optimizer, scaler, epoch)
            elif method == 'queue':
                loss = train_with_queue(model, train_loader, optimizer, scaler, epoch, embedding_queue)
            else:
                loss = train_one_epoch(model, train_loader, optimizer, scaler, epoch)
            
            # 评估
            accuracy = evaluate(model, test_dataset)
            
            # 记录结果
            results[method]['loss'].append(loss)
            results[method]['accuracy'].append(accuracy)
            
            print(f"Epoch {epoch+1}/{config['epochs']}")
            print(f"Loss: {loss:.4f}, Accuracy: {accuracy:.2f}%")
    
    return results

In [38]:
# In[6]: 实现结果可视化函数
def plot_experimental_results(results):
    # 创建图表
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    # 绘制损失曲线
    for method in results.keys():
        ax1.plot(results[method]['loss'], label=method)
    ax1.set_title('Training Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    # 绘制准确率曲线
    for method in results.keys():
        ax2.plot(results[method]['accuracy'], label=method)
    ax2.set_title('Test Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy (%)')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    plt.show()

In [39]:
# In[7]: 实现基础训练和评估函数
def train_one_epoch(model, dataloader, optimizer, scaler, epoch):
    model.train()
    total_loss = 0
    
    pbar = tqdm(dataloader, desc=f'Epoch {epoch}')
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        texts = [f"This is an image of a {train_dataset.classes[label]}" for label in labels]
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            img_emb, text_emb = model(images, texts)
            
            # 计算损失
            labels = torch.arange(img_emb.shape[0]).to(device)
            logits_img_text = torch.matmul(img_emb, text_emb.t()) * 2
            logits_text_img = torch.matmul(text_emb, img_emb.t()) * 2
            
            img_text_loss = F.cross_entropy(logits_img_text, labels)
            text_img_loss = F.cross_entropy(logits_text_img, labels)
            loss = (img_text_loss + text_img_loss) / 2
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        # 清理内存
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()
    
    return total_loss / len(dataloader)

def train_with_triplet_loss(model, dataloader, optimizer, scaler, epoch):
    model.train()
    total_loss = 0
    triplet_criterion = TripletLoss()
    
    pbar = tqdm(dataloader, desc=f'Epoch {epoch}')
    for batch_idx, (images, labels) in enumerate(pbar):
        images = images.to(device)
        texts = [f"This is an image of a {train_dataset.classes[label]}" for label in labels]
        
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            img_emb, text_emb = model(images, texts)
            loss = triplet_criterion(img_emb, text_emb)
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        total_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})
        
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()
    
    return total_loss / len(dataloader)

def evaluate(model, test_dataset):
    model.eval()
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    correct = 0
    total = 0
    
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc='Evaluating'):
            images = images.to(device)
            labels = labels.to(device)
            
            # 为每个类别生成文本描述
            class_texts = [f"This is an image of a {test_dataset.classes[i]}" for i in range(len(test_dataset.classes))]
            
            # 获取图像嵌入
            img_emb, _ = model(images, ["dummy text"])  # 文本参数在这里不重要
            
            # 获取所有类别的文本嵌入
            _, text_emb = model(torch.zeros_like(images[:1]), class_texts)
            
            # 计算相似度
            similarity = torch.matmul(img_emb, text_emb.t())
            
            # 获取预测
            _, predicted = similarity.max(1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    return accuracy

: 

In [40]:
# In[7]: 运行实验
config = {
    'learning_rate': 1e-4,
    'epochs': 5,
    'batch_size': 32,
    'queue_size': 30
}

# 保存初始权重
torch.save(model.state_dict(), 'initial_weights.pth')

# 运行实验
results = run_experiments(model, train_dataset, test_dataset, config)

# 可视化结果
plot_experimental_results(results)


Training with original method...


Epoch 0:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1/5
Loss: 2.1633, Accuracy: 92.38%


Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2/5
Loss: 2.0398, Accuracy: 92.17%


Epoch 2:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3/5
Loss: 2.0156, Accuracy: 93.86%


Epoch 3:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 4/5
Loss: 1.9966, Accuracy: 94.02%


Epoch 4:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 5/5
Loss: 1.9854, Accuracy: 93.55%

Training with triplet method...


Epoch 0:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1/5
Loss: 0.0371, Accuracy: 85.79%


Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2/5
Loss: 0.0290, Accuracy: 85.94%


Epoch 2:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3/5
Loss: 0.0280, Accuracy: 89.63%


Epoch 3:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 4/5
Loss: 0.0269, Accuracy: 88.55%


Epoch 4:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 5/5
Loss: 0.0272, Accuracy: 89.97%

Training with queue method...


Epoch 0:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 1/5
Loss: 6.1628, Accuracy: 10.08%


Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 2/5
Loss: 6.2687, Accuracy: 10.00%


Epoch 2:   0%|          | 0/1563 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Epoch 3/5
Loss: 6.2226, Accuracy: 10.00%


Epoch 3:   0%|          | 0/1563 [00:00<?, ?it/s]

In [28]:
# In[8]: 打印最终结果
print("\nFinal Results:")
for method in results.keys():
    final_accuracy = results[method]['accuracy'][-1]
    print(f"{method.capitalize()} method: {final_accuracy:.2f}% accuracy")


Final Results:
Original method: 93.20% accuracy
Triplet method: 90.05% accuracy
Queue method: 10.18% accuracy
