In [3]:
from datasets import load_dataset
import datasets
from collections import Counter
import re
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
import torch
import matplotlib.pyplot as plt
import tokenizers
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import os
import sys
from IPython.display import clear_output
from torch.utils.data import random_split
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
plt.ion()  # interactive matplotlib mode: figures update without blocking the cell
# Fall back to the CPU when no GPU is visible instead of crashing on a hard-coded 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Seed torch's global RNG so weight init, DataLoader shuffling and any unseeded
# random_split calls are repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)

In [4]:
# Load the IMDB dataset (the 'train' and 'test' splits are used below) into a local cache.
# os.path.join keeps the cache path portable: the raw string r'.\data_file' only
# resolves to a "data_file" sub-directory on Windows.
cache_dir = os.path.join('.', 'data_file')
imdb_dataset = load_dataset("imdb", cache_dir=cache_dir)

class IMDBDataset(Dataset):
    """Thin map-style torch ``Dataset`` over one split of a split-keyed dataset.

    Args:
        data_type: split name, e.g. ``'train'`` or ``'test'``.
        dataset_dict: mapping from split name to an indexable, sized collection.
            Defaults to the notebook-level ``imdb_dataset``, so existing calls such
            as ``IMDBDataset('train')`` behave exactly as before.
    """

    def __init__(self, data_type, dataset_dict=None):
        super().__init__()
        source = imdb_dataset if dataset_dict is None else dataset_dict
        self.dataset = source[data_type]

    def __getitem__(self, index):
        # Items are returned untouched; token_embed expects each one to be a
        # mapping with 'text' and 'label' keys.
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

Using the latest cached version of the dataset since imdb couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at data_file\imdb\plain_text\0.0.0\e6281661ce1c48d982bc483cf8a173c1bbeb5d31 (last modified on Thu Jun  5 21:25:29 2025).


In [5]:
# Wrap both IMDB splits (train first, then test) in the torch-compatible dataset class.
imdb_data, imdb_data_test = (IMDBDataset(split) for split in ('train', 'test'))

In [7]:
# Carve the IMDB test split into a held-out test set (80%) and a validation set (20%).
# For a 25k-review test split this is 20000 / 5000 — re-check if the data changes.
test_size = int(0.8 * len(imdb_data_test))
val_size = len(imdb_data_test) - test_size
# A dedicated, seeded generator makes the partition identical on every run, so
# validation/test numbers stay comparable across kernel restarts.
split_generator = torch.Generator().manual_seed(42)
test_dataset, val_dataset = random_split(imdb_data_test, [test_size, val_size], generator=split_generator)

In [8]:
from transformers import AutoModel, AutoTokenizer
# Re-bind `device` (first set in the imports cell): use the GPU when torch can see one, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [10]:
# Pretrained encoder plus its matching tokenizer. The cache directory can be overridden
# with the MODEL_CACHE_DIR environment variable instead of editing a machine-specific
# absolute path in two places (the original path remains the default).
MODEL_NAME = 'bert-base-uncased'
MODEL_CACHE_DIR = os.environ.get('MODEL_CACHE_DIR', r'D:\bigdata_project\models_file')
pretrain_model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
pretrain_model = pretrain_model.to(device)
token_tool = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)

In [11]:
import re
from bs4 import BeautifulSoup  # 用于去除HTML标签

def clean_text(text):
    """Normalize a raw IMDB review before tokenization.

    Strips HTML markup, removes @mentions and URLs, collapses whitespace,
    lower-cases, and expands "n't" / "'s" contractions.

    Args:
        text: raw review string (may contain HTML such as ``<br />``).

    Returns:
        The cleaned, lower-cased string without leading/trailing whitespace.
    """
    # Remove HTML tags (IMDB reviews contain <br /> line breaks). A space separator
    # keeps the words on either side of a tag from being glued into one token.
    text = BeautifulSoup(text, "html.parser").get_text(" ")

    text = re.sub(r"@[\w]+", "", text)   # drop @mentions
    text = re.sub(r"http\S+", "", text)  # drop URLs
    text = re.sub(r"\s+", " ", text)     # collapse runs of whitespace into one space

    # Lower-case BEFORE expanding contractions so "DON'T" is treated like "don't",
    # and map the typographic apostrophe (U+2019) to ASCII so "don’t" matches too.
    text = text.lower().replace("\u2019", "'")
    # Irregular negations first: a plain "n't" -> " not" would turn
    # "won't" into "wo not" and "can't" into "ca not".
    text = text.replace("won't", "will not").replace("can't", "can not")
    text = text.replace("n't", " not")
    # NOTE(review): "'s" -> " is" also rewrites possessives ("film's" -> "film is");
    # kept to preserve the original preprocessing.
    text = text.replace("'s", " is")

    return text.strip()

def token_embed(text_dict, max_length=512):
    """DataLoader ``collate_fn``: clean, tokenize and batch a list of IMDB samples.

    Args:
        text_dict: the list of samples the DataLoader hands over; each sample is a
            mapping with a ``'text'`` string and an integer ``'label'``.
        max_length: longest token sequence kept after truncation (512 is the
            maximum sequence length of bert-base).

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask`` tensors of
        shape (batch, seq_len) plus a ``label`` tensor of shape (batch,), where
        ``seq_len`` is the longest (truncated) sequence in this batch.
    """
    texts = [clean_text(sample['text']) for sample in text_dict]
    labels = torch.tensor([sample['label'] for sample in text_dict])

    # Pad only to the longest review in the batch instead of always to 512 tokens.
    # Padded positions are masked out via attention_mask, so the encoder output is
    # effectively unchanged, but batches of short reviews run far cheaper.
    encoded = token_tool(texts, max_length=max_length, padding='longest',
                         truncation=True, return_tensors='pt', add_special_tokens=True)
    return {'input_ids': encoded['input_ids'],
            'token_type_ids': encoded['token_type_ids'],
            'attention_mask': encoded['attention_mask'],
            'label': labels}

BATCH_SIZE = 64

# Training data is reshuffled every epoch. The evaluation loaders keep a fixed order so
# each evaluation scores the SAME samples: with shuffle=True + drop_last=True a different
# random remainder was discarded every time, making accuracies noisy and non-comparable.
# drop_last=True is kept on the eval loaders because per-batch accuracy elsewhere in this
# notebook may be normalized by a fixed batch size; disable it only once that code divides
# by the actual number of samples.
train_data_iter = DataLoader(imdb_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, collate_fn=token_embed)
val_data_iter = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, collate_fn=token_embed)
test_data_iter = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, collate_fn=token_embed)

In [12]:
class Bert(nn.Module):
    """Binary classifier: pretrained encoder + 2-layer MLP head on the first-token vector.

    Keyword Args:
        encoder: encoder module returning an object with ``last_hidden_state``.
            Defaults to the notebook-level ``pretrain_model`` so ``Bert()`` is unchanged.
        freeze_encoder: if True (default, the original behaviour) the encoder runs
            without gradient tracking, so only the head is trained. Pass False to
            fine-tune the encoder as well.
    """

    def __init__(self, *args, encoder=None, freeze_encoder=True, **kwargs):
        super().__init__(*args, **kwargs)
        self.pre_bert = pretrain_model if encoder is None else encoder
        self.freeze_encoder = freeze_encoder
        # 768 must match the encoder's hidden size (bert-base); 2 output classes.
        self.fc1 = nn.Linear(768, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 2)
        self.dropout = nn.Dropout(0.2)
        # The layers are registered both as attributes and inside `sequential`, so they
        # appear twice in the printed module tree; forward() only uses `sequential`.
        self.sequential = nn.Sequential(self.fc1, self.relu, self.dropout, self.fc2)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Return (batch, 2) logits for a batch of tokenized inputs."""
        # Turn autograd off through the encoder when frozen, but never turn it back ON
        # here, so an outer torch.no_grad() (e.g. during evaluation) is still respected.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_encoder):
            encoded = self.pre_bert(input_ids=input_ids, token_type_ids=token_type_ids,
                                    attention_mask=attention_mask)
        cls_vector = encoded.last_hidden_state[:, 0]  # hidden state of the first ([CLS]) token
        return self.sequential(cls_vector)
# Build the classifier around the global pretrained encoder, then move it to `device`.
bert = Bert()
bert = bert.to(device)

In [13]:
# TensorBoard event files for this run are written under runs/experiment_1.
writer = SummaryWriter(log_dir='runs/experiment_1')

In [None]:
# NOTE(review): lr=1e-5 is a typical full fine-tuning rate; if the encoder stays frozen,
# only the small head learns and a larger rate may converge much faster.
optimizer = optim.Adam(bert.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss()

NUM_EPOCHS = 500  # upper bound on passes over the training set
EVAL_EVERY = 100  # evaluate and log every N training steps

loss_list = []
train_accuracy_list = []
accuracy_list = []


def evaluate(model, data_iter):
    """Return the accuracy of ``model`` over every batch yielded by ``data_iter``.

    Correct predictions are divided by the number of samples actually seen, so the
    result is right even when a batch is smaller than the nominal batch size.
    Leaves ``model`` in eval mode; the training loop switches it back every step.
    """
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for batch in data_iter:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(batch['input_ids'], batch['token_type_ids'], batch['attention_mask'])
            correct += (logits.argmax(dim=1) == batch['label']).sum().item()
            total += batch['label'].numel()
    return correct / total if total else float('nan')


steps_per_epoch = len(train_data_iter)
for epoch in range(NUM_EPOCHS):
    for i, sample in enumerate(train_data_iter):
        bert.train()
        sample = {k: v.to(device) for k, v in sample.items()}
        optimizer.zero_grad()
        result = bert(sample['input_ids'], sample['token_type_ids'], sample['attention_mask'])
        loss = loss_fn(result, sample['label'])
        loss.backward()
        optimizer.step()

        if i % EVAL_EVERY == 0:
            # Measure training accuracy on THIS training batch before evaluation runs;
            # reading `result`/`sample` after a validation loop would report a
            # validation batch instead.
            train_accuracy = (result.argmax(dim=1) == sample['label']).float().mean().item()
            val_accuracy = evaluate(bert, val_data_iter)

            loss_list.append(loss.item())
            train_accuracy_list.append(train_accuracy)
            accuracy_list.append(val_accuracy)
            print(f'损失值：{loss_list[-1]},训练集准确度：{train_accuracy_list[-1]},验证集准确度：{accuracy_list[-1]}')

            # Monotonic global step: `epoch + i // 100` collides across epochs
            # (epoch 0 / step 100 and epoch 1 / step 0 would both log at step 1).
            global_step = epoch * steps_per_epoch + i
            writer.add_scalar('loss', loss_list[-1], global_step)
            writer.add_scalar('train_accuracy', train_accuracy_list[-1], global_step)
            writer.add_scalar('accuracy_rate', accuracy_list[-1], global_step)

In [69]:
# Quick check: number of correct predictions for the batch currently held in `result` / `sample`.
# Compare against the label tensor directly; re-wrapping an existing tensor with torch.tensor()
# copy-constructs it and emits a UserWarning.
(result.argmax(dim=1) == sample['label'].long()).sum()

tensor(14)

In [74]:
# Display the classifier's module tree.
# NOTE(review): the saved output below does not list the `pre_bert` submodule that
# Bert.__init__ assigns, so it appears to come from an older version of the class.
bert

Bert(
  (fc1): Linear(in_features=768, out_features=256, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=256, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
  (sequential): Sequential(
    (0): Linear(in_features=768, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=256, out_features=2, bias=True)
  )
)