In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from transformers import BertTokenizer, BertModel
import torch
from torch import nn
import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_csv(train_path, test_path):
    """Load the train/test CSVs and return review texts and ratings, skipping 'Else' rows.

    Rows whose ``Product Class`` column equals ``'Else'`` are dropped from both splits
    before the ``review`` and ``rating`` columns are extracted.

    Args:
        train_path: path to the training CSV.
        test_path: path to the test CSV.

    Returns:
        Tuple ``(train_reviews, train_ratings, test_reviews, test_ratings)`` of Python
        lists, each preserving the CSV row order.
    """
    raw_frames = [pd.read_csv(path) for path in (train_path, test_path)]

    columns = []
    for frame in raw_frames:
        kept = frame[frame['Product Class'] != 'Else']
        columns.append(kept['review'].tolist())
        columns.append(kept['rating'].tolist())

    train_reviews, train_ratings, test_reviews, test_ratings = columns
    return train_reviews, train_ratings, test_reviews, test_ratings

# Load both drug-review splits (rows labelled 'Else' are filtered out by load_csv).
train_reviews, train_ratings, test_reviews, test_ratings = load_csv(
    train_path='./data/drugsComTrain_raw_addclass.csv',
    test_path='./data/drugsComTest_raw_addclass.csv',
)

In [3]:
# Sanity check: show the distinct rating values present in the training split.
print(torch.tensor(train_ratings).unique())

tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])


In [4]:
# WordPiece tokenizer for the same 'bert-base-uncased' checkpoint that main() loads as the encoder.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [None]:
def _encode_reviews(tokenizer, texts, max_length=512):
    """Encode each text separately into a fixed-length (truncated + padded) encoding.

    Args:
        tokenizer: object exposing ``encode_plus`` (e.g. a Hugging Face tokenizer).
        texts: iterable of raw review strings.
        max_length: sequence length used for both truncation and padding.

    Returns:
        List with one ``encode_plus`` result per text, in input order.
    """
    return [
        tokenizer.encode_plus(
            text,
            truncation=True,
            add_special_tokens=True,
            max_length=max_length,
            # padding='max_length' is the supported replacement for the
            # deprecated pad_to_max_length=True.
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        for text in texts
    ]


def tokenize(tokenizer, train_reviews, test_reviews, max_length=512):
    """Tokenize the train and test reviews with identical settings.

    Args:
        tokenizer: object exposing ``encode_plus``.
        train_reviews: iterable of training review strings.
        test_reviews: iterable of test review strings.
        max_length: sequence length every encoding is truncated/padded to
            (default 512, the value previously hard-coded).

    Returns:
        Tuple ``(train_reviews_token, test_reviews_token)``: one list of
        per-review encodings for each split.
    """
    return (
        _encode_reviews(tokenizer, train_reviews, max_length),
        _encode_reviews(tokenizer, test_reviews, max_length),
    )


# Encode both splits once up front; the Dataset wrappers below index into these lists.
train_reviews_token, test_reviews_token = tokenize(
    tokenizer,
    train_reviews=train_reviews,
    test_reviews=test_reviews,
)



In [6]:
class Review_Rating_Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenized reviews with 0-based rating labels.

    Args:
        reviews_token: indexable sequence of per-review mappings (e.g. tokenizer
            encodings) whose tensor values carry a leading dimension of size 1
            that is squeezed away on access.
        rating: indexable sequence of integer ratings; presumably 1-based (the
            notebook's unique() check printed 1..10) — each is shifted down by one.
    """

    def __init__(self, reviews_token, rating):
        self.review = reviews_token
        self.rating = rating

    def __len__(self):
        return len(self.rating)

    def __getitem__(self, idx):
        encoding = self.review[idx]
        item = {}
        for key, tensor in encoding.items():
            # Drop the leading size-1 dimension so the DataLoader can stack items.
            item[key] = tensor.squeeze(dim=0)
        label = self.rating[idx] - 1  # shift to a 0-based target
        item["rating"] = torch.tensor(label)
        return item


# Pair each split's encodings with its raw ratings (the Dataset shifts ratings to 0-based).
train_dataset = Review_Rating_Dataset(reviews_token=train_reviews_token, rating=train_ratings)
test_dataset = Review_Rating_Dataset(reviews_token=test_reviews_token, rating=test_ratings)

In [None]:
# Shuffle only the training split; evaluation iterates in a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

In [8]:
class BertWithMLP(nn.Module):
    """Encoder followed by a 3-layer MLP head applied to the first ([CLS]) token.

    Args:
        bert: module called as ``bert(input_ids=..., attention_mask=...)`` whose
            result exposes ``last_hidden_state`` of rank 3 (indexed ``[:, 0]``),
            with a last dimension equal to ``hidden_size``.
        hidden_size: width of the encoder's hidden states.
        mlp_hidden_size1: width of the first MLP hidden layer.
        mlp_hidden_size2: width of the second MLP hidden layer.
        num_classes: output width. With the default 1 the trailing dimension is
            squeezed, so ``forward`` returns one scalar per example.
    """

    def __init__(self, bert, hidden_size=768, mlp_hidden_size1=1024, mlp_hidden_size2=256, num_classes=1):
        super().__init__()
        self.bert = bert

        # Linear -> ReLU -> Dropout for each hidden width, then a final projection.
        # Layers are created in the same order as a hand-written Sequential so the
        # state_dict keys (mlp.0, mlp.3, mlp.6) and RNG-driven init are unchanged.
        head_layers = []
        in_features = hidden_size
        for out_features in (mlp_hidden_size1, mlp_hidden_size2):
            head_layers += [nn.Linear(in_features, out_features), nn.ReLU(), nn.Dropout(0.2)]
            in_features = out_features
        head_layers.append(nn.Linear(in_features, num_classes))
        self.mlp = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        hidden_states = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        cls_embedding = hidden_states[:, 0]  # first token of every sequence
        return self.mlp(cls_embedding).squeeze(-1)

In [10]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch must provide ``'input_ids'``, ``'attention_mask'`` and ``'rating'``.
    The model output is treated as a regression prediction: accuracy counts
    samples whose rounded prediction equals the label, error is the absolute
    difference between label and prediction.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``.
        dataloader: DataLoader yielding dict batches; ``len(dataloader)`` and
            ``len(dataloader.dataset)`` are used for averaging.
        optimizer: optimizer stepping the trainable parameters.
        criterion: loss taking ``(outputs, labels.float())``.
        device: device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, error)``: mean per-batch loss (float),
        per-sample accuracy (0-dim float64 tensor, so callers can ``.item()``
        or format it), and per-sample mean absolute error (float).
    """
    model.train()
    total_loss = 0.0
    correct_predictions = 0
    total_error = 0.0

    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['rating'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        # Metrics are bookkeeping only. Detach before computing them and reduce
        # to Python numbers: accumulating `outputs`-derived tensors across
        # batches would keep every batch's autograd graph alive for the epoch.
        preds = outputs.detach()
        abs_error = torch.abs(labels - preds)
        batch_correct = (torch.round(preds) == labels).sum().item()
        batch_loss = loss.item()

        correct_predictions += batch_correct
        total_error += abs_error.sum().item()
        total_loss += batch_loss

        # Update the progress-bar readout
        progress_bar.set_postfix({
            'loss': batch_loss,
            'acc': batch_correct / len(labels),
            'error': abs_error.mean().item()
        })

    n_samples = len(dataloader.dataset)
    avg_loss = total_loss / len(dataloader)
    accuracy = torch.tensor(correct_predictions / n_samples, dtype=torch.float64)
    error = total_error / n_samples
    return avg_loss, accuracy, error

def eval_model(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` with gradients disabled.

    Each batch must provide ``'input_ids'``, ``'attention_mask'`` and ``'rating'``.
    Accuracy counts samples whose rounded prediction equals the label; error is
    the absolute label-prediction difference.

    Returns:
        Tuple ``(avg_loss, accuracy, error)``: mean per-batch loss (float),
        per-sample accuracy (0-dim double tensor) and per-sample mean
        absolute error (float).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    abs_err_sum = 0.0

    with torch.no_grad():
        bar = tqdm(dataloader, desc="Evaluating", leave=False)
        for batch in bar:
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in ('input_ids', 'attention_mask', 'rating')
            )

            predictions = model(input_ids=input_ids, attention_mask=attention_mask)
            batch_loss = criterion(predictions, labels.float())

            hits = torch.round(predictions) == labels
            abs_err = torch.abs(labels - predictions)
            n_correct += torch.sum(hits)
            abs_err_sum += torch.sum(abs_err)
            loss_sum += batch_loss.item()

            bar.set_postfix({
                'loss': batch_loss.item(),
                'acc': torch.sum(hits).item() / len(labels),
                'error': torch.mean(abs_err).item()
            })

    avg_loss = loss_sum / len(dataloader)
    n_samples = len(dataloader.dataset)
    return avg_loss, n_correct.double() / n_samples, abs_err_sum.item() / n_samples

# 4. Main training loop
def train_and_evaluate(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device,
    epochs,
    model_save_path,
    eval_every=1  # run validation every `eval_every` epochs
):
    """Train for ``epochs`` epochs, validating periodically and checkpointing improvements.

    A checkpoint (``model.state_dict()``) is written to ``model_save_path``
    whenever validation error (mean absolute error, lower is better) reaches a
    new minimum or validation accuracy reaches a new maximum.

    Args:
        model, optimizer, criterion, device: forwarded to ``train_epoch`` / ``eval_model``.
        train_loader: training DataLoader.
        val_loader: validation DataLoader, or None to skip validation.
        epochs: number of training epochs.
        model_save_path: file the best checkpoint is written to.
        eval_every: validate on epochs that are multiples of this value.

    Returns:
        Dict of metric lists: ``train_loss``/``train_acc``/``train_error`` (one
        entry per epoch) and ``val_loss``/``val_acc``/``val_error`` (one entry
        per validated epoch).
    """
    # Error is an MAE, so the best value is the smallest; accuracy's best is the largest.
    best_val_error = float('inf')
    best_val_acc = float('-inf')
    history = {
        'train_loss': [],
        'train_acc': [],
        'train_error': [],
        'val_loss': [],
        'val_acc': [],
        'val_error': []
    }

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        # Training phase
        train_loss, train_acc, train_error = train_epoch(
            model, train_loader, optimizer, criterion, device)
        train_acc = float(train_acc)  # accepts a 0-dim tensor or a plain number
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['train_error'].append(train_error)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train Error: {train_error:.4f}")

        # Validation phase
        if val_loader is None or epoch % eval_every != 0:
            continue

        val_loss, val_acc, val_error = eval_model(
            model, val_loader, criterion, device)
        val_acc = float(val_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_error'].append(val_error)

        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val Error: {val_error:.4f}")

        # Save the best model. Track both bests independently so an error
        # improvement does not skip the accuracy bookkeeping (and vice versa).
        error_improved = val_error < best_val_error
        acc_improved = val_acc > best_val_acc
        if error_improved:
            best_val_error = val_error
        if acc_improved:
            best_val_acc = val_acc
        if error_improved or acc_improved:
            torch.save(model.state_dict(), model_save_path)
            print(f"New best model saved to {model_save_path} with val_acc: {val_acc:.4f} | val_error: {val_error:.4f}")

    return history

# 5. Entry point
def main(seed=42):
    """Fine-tune BERT + MLP as a rating regressor and report the best validation metrics.

    Uses the notebook-level ``train_loader`` and ``test_loader`` built in earlier
    cells (``test_loader`` serves as the validation set).

    Args:
        seed: seed for torch's global RNG (head initialisation, dropout and
            DataLoader shuffling) so runs are reproducible.

    Returns:
        The metric history dict produced by ``train_and_evaluate``.
    """
    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    BERT = BertModel.from_pretrained("bert-base-uncased")
    model = BertWithMLP(BERT, hidden_size=768, mlp_hidden_size1=256, mlp_hidden_size2=10, num_classes=1)
    model.to(device)

    # The embedding parameters are not given to the optimizer below, so they are
    # never updated. Freezing them avoids computing (and accumulating, since
    # zero_grad only clears optimizer params) gradients nobody uses.
    for param in model.bert.embeddings.parameters():
        param.requires_grad = False

    # One param group per encoder layer (last layer first), plus the MLP head.
    encoder_lr = 5e-5
    head_lr = 1e-4
    param_groups = [
        {'params': layer.parameters(), 'lr': encoder_lr}
        for layer in reversed(list(model.bert.encoder.layer))
    ]
    param_groups.append({'params': model.mlp.parameters(), 'lr': head_lr})
    optimizer = torch.optim.Adam(param_groups)

    criterion = torch.nn.MSELoss()
    epochs = 20
    model_save_path = "./best_model.pth"

    # Create the checkpoint directory; dirname is '' for a bare filename and
    # os.makedirs('') would raise.
    save_dir = os.path.dirname(model_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    history = train_and_evaluate(
        model=model,
        train_loader=train_loader,
        val_loader=test_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        epochs=epochs,
        model_save_path=model_save_path,
        eval_every=1  # validate every epoch
    )

    print("\nTraining complete!")
    if history['val_acc']:
        print(f"Best validation accuracy: {max(history['val_acc']):.4f}")
        # Error is an MAE: the best value is the minimum.
        print(f"Best validation error: {min(history['val_error']):.4f}")
    return history

# Inside a notebook kernel __name__ is "__main__", so executing this cell starts training.
if __name__ == "__main__":
    main()


Epoch 1/20


                                                                                             

KeyboardInterrupt: 