# CLIPText Implementation with AG News Dataset

This notebook implements the CLIPText approach for zero-shot text classification using the AG News dataset. AG News contains news articles in 4 categories:
1. World
2. Sports
3. Business
4. Tech/Sci

In [2]:
# Cell 1: 导入必要的包
import torch
import clip
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
from sklearn.metrics import classification_report
import os
import torch.nn.functional as F
from PIL import Image, ImageDraw
import re
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import nltk
nltk.download('punkt')

# 加载模型和数据集
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
dataset = load_dataset("ag_news")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Found cached dataset parquet (/root/.cache/huggingface/datasets/parquet/ag_news-92271709ed454db0/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Cell 2: 创建类别图像和特征
categories = {
    0: "World News",
    1: "Sports News", 
    2: "Business News",
    3: "Technology and Science News"
}

# 为每个类别定义关键词和特征
category_keywords = {
    0: ["global", "international", "world", "country", "nation", "political", "government"],
    1: ["sports", "game", "player", "team", "tournament", "championship", "athlete"],
    2: ["business", "market", "economy", "company", "financial", "trade", "stock"],
    3: ["technology", "science", "research", "innovation", "digital", "tech", "scientific"]
}

def create_category_image(label, size=(224, 224), variation=0):
    """创建更具区分性的类别图像"""
    img = Image.new('RGB', size, 'white')
    draw = ImageDraw.Draw(img)
    
    # 基础渐变
    arr = np.zeros((size[0], size[1], 3), dtype=np.uint8)
    
    if label == 0:  # World News
        arr[:, :, 2] = np.linspace(100, 255, size[0])[:, None]
        if variation == 0:
            draw.ellipse([40, 40, 184, 184], outline='white', width=4)
            draw.line([40, 112, 184, 112], fill='white', width=3)
        elif variation == 1:
            for i in range(4):
                x = 60 + i * 40
                draw.ellipse([x, 90, x+30, 120], outline='white', width=2)
        else:
            for i in range(4):
                draw.line([0, i*60, 224, 224-i*60], fill='white', width=2)
                
    elif label == 1:  # Sports News
        arr[:, :, 1] = np.linspace(100, 255, size[0])[:, None]
        if variation == 0:
            draw.ellipse([70, 70, 154, 154], outline='white', width=4)
            draw.line([70, 112, 154, 112], fill='white', width=2)
        elif variation == 1:
            draw.rectangle([50, 50, 174, 174], outline='white', width=3)
            draw.line([112, 50, 112, 174], fill='white', width=2)
        else:
            # 使用线条绘制人形图标
            points = [(112, 50), (90, 100), (134, 100)]
            # 绘制三角形的三条边
            draw.line([points[0], points[1]], fill='white', width=2)
            draw.line([points[1], points[2]], fill='white', width=2)
            draw.line([points[2], points[0]], fill='white', width=2)
            # 绘制身体
            draw.line([112, 100, 112, 150], fill='white', width=2)
            
    elif label == 2:  # Business News
        arr[:, :, 0] = np.linspace(100, 255, size[0])[:, None]
        if variation == 0:
            points = [(50,174), (90,130), (130,90), (174,50)]
            draw.line(points, fill='white', width=3)
        elif variation == 1:
            draw.text((90, 90), "$", fill='white', size=60)
        else:
            for i in range(4):
                height = 40 + i * 30
                draw.rectangle([50+i*40, 224-height, 80+i*40, 224], outline='white', width=2)
            
    else:  # Tech News
        arr[:, :, 0] = np.linspace(100, 200, size[0])[:, None]
        arr[:, :, 2] = np.linspace(100, 200, size[0])[:, None]
        if variation == 0:
            for i in range(5):
                draw.line([0, i*50, 224, i*50], fill='white', width=2)
                draw.line([i*50, 0, i*50, 224], fill='white', width=2)
        elif variation == 1:
            for i in range(5):
                for j in range(5):
                    draw.text((30+i*40, 30+j*40), "01", fill='white')
        else:
            draw.rectangle([60, 60, 164, 164], outline='white', width=3)
            draw.line([60, 112, 164, 112], fill='white', width=2)
            draw.line([112, 60, 112, 164], fill='white', width=2)
    
    img = Image.fromarray(arr)
    return img

# 创建和处理类别图像
print("Creating category images...")
category_images = {}
for label in categories.keys():
    print(f"Processing category: {categories[label]}")
    
    # 创建多个变体
    processed_variants = []
    for i in range(3):  # 每个类别创建3个变体
        img = create_category_image(label, variation=i)
        img_path = f'agnews_images/category_{label}_variant_{i}.jpg'
        img.save(img_path)
        print(f"Saved image variant {i} to {img_path}")
        
        processed_img = preprocess(img).unsqueeze(0).to(device)
        processed_variants.append(processed_img)
    
    category_images[label] = processed_variants
    print(f"Successfully processed all variants for {categories[label]}\n")

# 预处理并缓存图像特征
print("Preprocessing category images...")
category_features = {}
for label, img_variants in category_images.items():
    features = []
    with torch.no_grad():
        for img in img_variants:
            features.append(model.encode_image(img))
    category_features[label] = features
    print(f"Processed features for {categories[label]}")

Creating category images...
Processing category: World News
Saved image variant 0 to agnews_images/category_0_variant_0.jpg
Saved image variant 1 to agnews_images/category_0_variant_1.jpg
Saved image variant 2 to agnews_images/category_0_variant_2.jpg
Successfully processed all variants for World News

Processing category: Sports News
Saved image variant 0 to agnews_images/category_1_variant_0.jpg
Saved image variant 1 to agnews_images/category_1_variant_1.jpg
Saved image variant 2 to agnews_images/category_1_variant_2.jpg
Successfully processed all variants for Sports News

Processing category: Business News
Saved image variant 0 to agnews_images/category_2_variant_0.jpg
Saved image variant 1 to agnews_images/category_2_variant_1.jpg
Saved image variant 2 to agnews_images/category_2_variant_2.jpg
Successfully processed all variants for Business News

Processing category: Technology and Science News
Saved image variant 0 to agnews_images/category_3_variant_0.jpg
Saved image variant 1 t

In [4]:
# Cell 3: 定义预测函数
def clean_text(text):
    """增强的文本清理函数"""
    text = re.sub(r'<[^>]+>', '', text)
    text = re.sub(r'http\S+|www.\S+', '', text)
    text = re.sub(r'[^\w\s.,!?-]', ' ', text)
    text = ' '.join(text.split())
    
    ps = PorterStemmer()
    words = word_tokenize(text)
    text = ' '.join([ps.stem(word) for word in words])
    
    return text

def predict_basic_clip(text):
    """基础CLIP方法 - 不使用prompt"""
    text_input = clip.tokenize([text[:250]]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_input)
        
        similarities = {}
        for label, img_features_list in category_features.items():
            img_similarities = []
            for img_features in img_features_list:
                similarity = F.cosine_similarity(text_features, img_features)
                img_similarities.append(similarity.item())
            similarities[label] = np.mean(img_similarities)
            
        return max(similarities.items(), key=lambda x: x[1])[0]

def predict_simple_prompt(text):
    """简单Prompt方法 - 使用单一prompt模板"""
    prompted_text = f"This is a news article about: {text[:250]}"
    text_input = clip.tokenize([prompted_text]).to(device)
    
    with torch.no_grad():
        text_features = model.encode_text(text_input)
        
        similarities = {}
        for label, img_features_list in category_features.items():
            img_similarities = []
            for img_features in img_features_list:
                similarity = F.cosine_similarity(text_features, img_features)
                img_similarities.append(similarity.item())
            similarities[label] = np.mean(img_similarities)
            
        return max(similarities.items(), key=lambda x: x[1])[0]

def predict_category_with_confidence(text):
    """完整的PROMPT-CLIPTEXT方法"""
    cleaned_text = clean_text(text)
    
    # 计算文本中包含的类别关键词
    keyword_scores = {label: 0 for label in categories.keys()}
    for label, keywords in category_keywords.items():
        for keyword in keywords:
            if keyword in cleaned_text.lower():
                keyword_scores[label] += 1
    
    max_score = max(keyword_scores.values()) + 1e-6
    keyword_weights = {k: v/max_score for k, v in keyword_scores.items()}
    
    # 创建多个prompt
    prompts = [
        f"This is a news article discussing: {cleaned_text[:250]}",
        f"Here's a news report about: {cleaned_text[:250]}",
        f"Breaking news story: {cleaned_text[:250]}"
    ]
    
    # 添加类别特定的prompt
    category_prompts = {
        0: [f"International news coverage: {cleaned_text[:250]}"],
        1: [f"Sports coverage about: {cleaned_text[:250]}"],
        2: [f"Business and finance report: {cleaned_text[:250]}"],
        3: [f"Tech and science report about: {cleaned_text[:250]}"]
    }
    
    for category_specific_prompts in category_prompts.values():
        prompts.extend(category_specific_prompts)
    
    similarities = {label: [] for label in categories.keys()}
    
    for prompted_text in prompts:
        text_input = clip.tokenize([prompted_text]).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text_input)
            
            for label, img_features_list in category_features.items():
                prompt_similarities = []
                for img_features in img_features_list:
                    similarity = F.cosine_similarity(text_features, img_features)
                    prompt_similarities.append(similarity.item())
                
                top_similarities = sorted(prompt_similarities, reverse=True)[:2]
                similarities[label].append(np.mean(top_similarities))
    
    final_scores = {}
    for label in categories.keys():
        clip_score = np.mean(similarities[label])
        keyword_weight = keyword_weights[label]
        
        class_weight = 1.0
        if label == 1:  # Sports News
            class_weight = 1.2
        elif label == 2:  # Business News
            class_weight = 0.8
            
        final_scores[label] = clip_score * (0.7 + 0.3 * keyword_weight) * class_weight
    
    max_score = max(final_scores.values())
    max_label = max(final_scores.items(), key=lambda x: x[1])[0]
    
    confidence = max_score / (sum(final_scores.values()) + 1e-6)
    
    if confidence < 0.3:
        sorted_scores = sorted(final_scores.items(), key=lambda x: x[1], reverse=True)
        if len(sorted_scores) > 1:
            max_label = sorted_scores[1][0]
    
    return max_label, confidence, final_scores


In [5]:
# Cell 4: 对比实验
def evaluate_method(predict_fn, test_texts, test_labels, method_name):
    predictions = []
    correct = 0
    total = 0
    
    print(f"\nEvaluating {method_name}...")
    for i, (text, true_label) in enumerate(zip(test_texts, test_labels)):
        try:
            if predict_fn == predict_category_with_confidence:
                pred, _, _ = predict_fn(text)
            else:
                pred = predict_fn(text)
                
            predictions.append(pred)
            is_correct = pred == true_label
            if is_correct:
                correct += 1
            total += 1
            
            if (i + 1) % 20 == 0:
                print(f"Progress: {i+1}/{len(test_texts)}, Current accuracy: {correct/total:.2%}")
                
        except Exception as e:
            print(f"Error processing text: {text[:50]}...")
            print(f"Error: {str(e)}")
            predictions.append(0)
    
    print(f"\n=== {method_name} Classification Report ===")
    report = classification_report(test_labels, predictions, 
                                 target_names=[categories[i] for i in range(4)],
                                 output_dict=True)
    print(classification_report(test_labels, predictions, 
                              target_names=[categories[i] for i in range(4)]))
    return report

# 运行对比实验
test_size = 100
test_texts = dataset['test']['text'][:test_size]
test_labels = dataset['test']['label'][:test_size]

# 评估三种方法
basic_results = evaluate_method(predict_basic_clip, test_texts, test_labels, "Basic CLIP")
simple_prompt_results = evaluate_method(predict_simple_prompt, test_texts, test_labels, "Simple Prompt")
full_prompt_results = evaluate_method(predict_category_with_confidence, test_texts, test_labels, "PROMPT-CLIPTEXT")

# 打印对比结果
print("\n=== Methods Comparison ===")
print(f"{'Method':<20} {'Accuracy':<10} {'Macro F1':<10}")
print("-" * 40)
print(f"{'Basic CLIP':<20} {basic_results['accuracy']:.3f}     {basic_results['macro avg']['f1-score']:.3f}")
print(f"{'Simple Prompt':<20} {simple_prompt_results['accuracy']:.3f}     {simple_prompt_results['macro avg']['f1-score']:.3f}")
print(f"{'PROMPT-CLIPTEXT':<20} {full_prompt_results['accuracy']:.3f}     {full_prompt_results['macro avg']['f1-score']:.3f}")


Evaluating Basic CLIP...
Progress: 20/100, Current accuracy: 15.00%
Progress: 40/100, Current accuracy: 10.00%
Progress: 60/100, Current accuracy: 15.00%
Progress: 80/100, Current accuracy: 15.00%
Error processing text: Staples Profit Up, to Enter China Market  NEW YORK...
Error: Input Staples Profit Up, to Enter China Market  NEW YORK (Reuters) - Staples Inc. &lt;A HREF="http://www.investor.reuters.com/FullQuote.aspx?ticker=SPLS.O target=/stocks/quickinfo/fullquote"&gt;SPLS.O&lt;/A&gt;, the top U.S.  office products retailer, on T is too long for context length 77
Progress: 100/100, Current accuracy: 18.18%

=== Basic CLIP Classification Report ===
                             precision    recall  f1-score   support

                 World News       0.18      0.07      0.10        30
                Sports News       0.00      0.00      0.00        21
              Business News       0.13      0.83      0.22        12
Technology and Science News       0.86      0.16      0.27      

In [6]:
# Cell 5: 基准模型比较
def train_evaluate_xgboost():
    print("\nTraining and evaluating XGBoost + TF-IDF (Few-shot Learning)...")
    start_time = time.time()
    
    # TF-IDF特征提取
    print("Extracting TF-IDF features...")
    vectorizer = TfidfVectorizer(max_features=5000)
    X_train = vectorizer.fit_transform(train_texts)
    X_test = vectorizer.transform(test_texts)
    
    # 训练XGBoost
    print("Training XGBoost model...")
    model = xgb.XGBClassifier(
        max_depth=7,
        learning_rate=0.1,
        n_estimators=100,
        use_label_encoder=False,
        eval_metric='mlogloss',
        verbose=1
    )
    
    # 训练模型
    eval_set = [(X_test, test_labels)]
    model.fit(X_train, train_labels, eval_set=eval_set, verbose=True)
    
    # 预测
    print("Making predictions...")
    predictions = model.predict(X_test)
    
    training_time = time.time() - start_time
    
    # 计算指标
    report = classification_report(test_labels, predictions, 
                                 target_names=[categories[i] for i in range(4)],
                                 output_dict=True)
    
    return report, training_time, predictions

# 运行基准模型
print("\nRunning benchmark comparisons...")
results = {}

# BERT
bert_report, bert_time, bert_preds = train_evaluate_bert()
results['BERT (Few-shot)'] = {'report': bert_report, 'time': bert_time, 'predictions': bert_preds}

# XGBoost
xgboost_report, xgboost_time, xgboost_preds = train_evaluate_xgboost()
results['XGBoost (Few-shot)'] = {'report': xgboost_report, 'time': xgboost_time, 'predictions': xgboost_preds}

# 为PROMPT-CLIPTEXT生成预测结果
prompt_predictions = []
for text in test_texts[:100]:  # 使用与之前相同的100个测试样本
    try:
        pred, _, _ = predict_category_with_confidence(text)
        prompt_predictions.append(pred)
    except Exception as e:
        prompt_predictions.append(0)
        print(f"Error in PROMPT-CLIPTEXT prediction: {str(e)}")

# 添加CLIP结果
results['PROMPT-CLIPTEXT (Zero-shot)'] = {
    'report': full_prompt_results,
    'time': None,
    'predictions': prompt_predictions
}

# 打印比较结果
print("\n=== Model Comparison ===")
print(f"{'Model':<25} {'Accuracy':<10} {'Macro F1':<10} {'Training Time':<15}")
print("-" * 60)

for model_name, result in results.items():
    accuracy = result['report']['accuracy']
    macro_f1 = result['report']['macro avg']['f1-score']
    time_str = f"{result['time']:.2f}s" if result['time'] else "N/A"
    print(f"{model_name:<25} {accuracy:.3f}     {macro_f1:.3f}     {time_str}")

# 打印详细分类报告
for model_name, result in results.items():
    print(f"\n=== {model_name} Detailed Report ===")
    if model_name == 'PROMPT-CLIPTEXT (Zero-shot)':
        # 只使用前100个样本进行评估
        print(classification_report(test_labels[:100], result['predictions'],
                                 target_names=[categories[i] for i in range(4)]))
    else:
        # 使用全部测试集
        print(classification_report(test_labels, result['predictions'],
                                 target_names=[categories[i] for i in range(4)]))


Running benchmark comparisons...


NameError: name 'train_evaluate_bert' is not defined