In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, AdamW
from transformers import BertModel, BertConfig
from tqdm import tqdm

In [3]:
# Step 1: Define the Dataset Class
class TokenDataset(Dataset):
    def __init__(self, file_path):
        with open(file_path, 'r') as f:
            # Read and split the data into tokens
            self.data = f.read().split(' ')
        
        # Create input-output pairs
        self.inputs = []
        self.labels = []
        
        for i in range(len(self.data) - 1):
            self.inputs.append(self.data[i])
            # Define the label: 0 for space, 1 for newline, 2 for other
            if self.data[i + 1] == "<space>":
                self.labels.append(0)  # Space
            elif self.data[i + 1] == "<newline>":
                self.labels.append(1)  # Newline
            else:
                self.labels.append(2)  # Other

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

In [32]:
# Load the dataset
file_path = 'data/preprocessed_cpp_data.txt'  # Update this with your actual file path
dataset = TokenDataset(file_path)

# Create DataLoader
batch_size = 32  # or your defined batch size
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Get and print dataset and DataLoader size
dataset_size = len(dataset)
num_batches = (dataset_size + batch_size - 1) // batch_size

print(f"Total dataset size: {dataset_size}")
print(f"Batch size: {batch_size}")
print(f"Number of batches: {num_batches}")

Total dataset size: 77431
Batch size: 32
Number of batches: 2420


In [28]:
for inputs, labels in dataloader:
    print(inputs)
    print(labels)
    break

('/', '/', '<space>', 'Copyright', '<space>', '(', 'c', ')', '<space>', '2011', '<space>', 'The', '<space>', 'LevelDB', '<space>', 'Authors', '.', '<space>', 'All', '<space>', 'rights', '<space>', 'reserved', '.', '<newline>', '/', '/', '<space>', 'Use', '<space>', 'of', '<space>')
tensor([2, 0, 2, 0, 2, 2, 2, 0, 2, 0, 2, 0, 2, 0, 2, 2, 0, 2, 0, 2, 0, 2, 2, 1,
        2, 2, 0, 2, 0, 2, 0, 2])


In [31]:
print(dataloader.size)

AttributeError: 'DataLoader' object has no attribute 'size'

In [10]:
# Step 3: Define a BERT Model from Scratch
class BertForSequenceClassification(nn.Module):
    def __init__(self, num_labels=3):
        super(BertForSequenceClassification, self).__init__()
        # Define the BERT configuration
        self.config = BertConfig(
            vocab_size=30522,  # This is the vocabulary size for BERT
            hidden_size=768,  # The size of the hidden layers
            num_hidden_layers=12,  # Number of transformer blocks
            num_attention_heads=12,  # Number of attention heads
            intermediate_size=3072,  # Size of the feed-forward layer
            hidden_act='gelu',  # Activation function
            num_labels=num_labels  # Number of labels for classification
        )
        # Define BERT model
        self.bert = BertModel(self.config)
        # Define classification head
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[1]  # Get the pooled output
        logits = self.classifier(pooled_output)
        return logits

In [11]:
# Step 4: Set Device and Initialize Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification(num_labels=3).to(device)

In [13]:
# Step 5: Define Tokenizer and Optimizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')  # Use a pre-trained tokenizer
optimizer = AdamW(model.parameters(), lr=5e-5)

In [33]:
# Step 6: Training Loop
epochs = 3
model.train()  # Set the model to training mode

for epoch in range(epochs):
    total_loss = 0
    for batch_idx, (inputs, labels) in enumerate(tqdm(dataloader)):
        # Tokenize the inputs
        inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt', max_length=128).to(device)

        # Convert labels to tensor correctly
        labels = torch.tensor(labels).long().clone().detach().to(device)

        # Forward pass
        outputs = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        loss_fn = nn.CrossEntropyLoss()
        loss = loss_fn(outputs, labels)
        total_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Print the loss for the current mini-batch
        print(f"Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(dataloader)}], Loss: {loss.item()}")

    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss}")


  labels = torch.tensor(labels).long().clone().detach().to(device)
  0%|                                                                               | 1/2420 [00:02<1:40:13,  2.49s/it]

Epoch [1/3], Batch [1/2420], Loss: 0.9389889240264893


  0%|                                                                               | 2/2420 [00:05<1:50:26,  2.74s/it]

Epoch [1/3], Batch [2/2420], Loss: 0.8532786965370178


  0%|                                                                               | 3/2420 [00:08<2:02:19,  3.04s/it]

Epoch [1/3], Batch [3/2420], Loss: 0.826081395149231


  0%|▏                                                                              | 4/2420 [00:12<2:11:26,  3.26s/it]

Epoch [1/3], Batch [4/2420], Loss: 0.9342685341835022


  0%|▏                                                                              | 5/2420 [00:14<1:58:17,  2.94s/it]

Epoch [1/3], Batch [5/2420], Loss: 0.9509748220443726


  0%|▏                                                                              | 6/2420 [00:19<2:21:54,  3.53s/it]

Epoch [1/3], Batch [6/2420], Loss: 0.8352755308151245


  0%|▏                                                                              | 7/2420 [00:21<2:06:29,  3.15s/it]

Epoch [1/3], Batch [7/2420], Loss: 0.8714090585708618


  0%|▎                                                                              | 8/2420 [00:24<1:56:49,  2.91s/it]

Epoch [1/3], Batch [8/2420], Loss: 0.8526116013526917


  0%|▎                                                                              | 9/2420 [00:26<1:50:13,  2.74s/it]

Epoch [1/3], Batch [9/2420], Loss: 0.8827657699584961


  0%|▎                                                                             | 10/2420 [00:28<1:45:30,  2.63s/it]

Epoch [1/3], Batch [10/2420], Loss: 0.6861468553543091


  0%|▎                                                                             | 11/2420 [00:32<1:54:56,  2.86s/it]

Epoch [1/3], Batch [11/2420], Loss: 0.7749091982841492


  0%|▍                                                                             | 12/2420 [00:34<1:49:20,  2.72s/it]

Epoch [1/3], Batch [12/2420], Loss: 0.7767430543899536


  1%|▍                                                                             | 13/2420 [00:38<1:58:06,  2.94s/it]

Epoch [1/3], Batch [13/2420], Loss: 0.848986029624939


  1%|▍                                                                             | 14/2420 [00:43<2:32:34,  3.80s/it]

Epoch [1/3], Batch [14/2420], Loss: 0.9186177849769592


  1%|▍                                                                             | 15/2420 [00:47<2:30:57,  3.77s/it]

Epoch [1/3], Batch [15/2420], Loss: 0.7950199842453003


  1%|▌                                                                             | 16/2420 [00:50<2:14:23,  3.35s/it]

Epoch [1/3], Batch [16/2420], Loss: 0.8984153270721436


  1%|▌                                                                             | 17/2420 [00:52<2:05:42,  3.14s/it]

Epoch [1/3], Batch [17/2420], Loss: 0.7576972246170044


  1%|▌                                                                             | 18/2420 [00:55<2:03:32,  3.09s/it]

Epoch [1/3], Batch [18/2420], Loss: 0.6971703171730042


  1%|▌                                                                             | 19/2420 [00:59<2:17:42,  3.44s/it]

Epoch [1/3], Batch [19/2420], Loss: 0.8344874978065491


  1%|▋                                                                             | 20/2420 [01:02<2:05:33,  3.14s/it]

Epoch [1/3], Batch [20/2420], Loss: 0.5822280049324036


  1%|▋                                                                             | 21/2420 [01:05<2:06:08,  3.15s/it]

Epoch [1/3], Batch [21/2420], Loss: 0.9260549545288086


  1%|▋                                                                             | 22/2420 [01:08<2:06:38,  3.17s/it]

Epoch [1/3], Batch [22/2420], Loss: 0.8770410418510437


  1%|▋                                                                             | 23/2420 [01:11<2:04:12,  3.11s/it]

Epoch [1/3], Batch [23/2420], Loss: 0.9328681230545044


  1%|▊                                                                             | 24/2420 [01:14<1:59:16,  2.99s/it]

Epoch [1/3], Batch [24/2420], Loss: 0.841187059879303


  1%|▊                                                                             | 25/2420 [01:16<1:52:14,  2.81s/it]

Epoch [1/3], Batch [25/2420], Loss: 0.8797575831413269


  1%|▊                                                                             | 26/2420 [01:19<1:47:51,  2.70s/it]

Epoch [1/3], Batch [26/2420], Loss: 0.6839094758033752


  1%|▊                                                                             | 27/2420 [01:22<1:57:37,  2.95s/it]

Epoch [1/3], Batch [27/2420], Loss: 0.8098555207252502


  1%|▉                                                                             | 28/2420 [01:27<2:13:43,  3.35s/it]

Epoch [1/3], Batch [28/2420], Loss: 0.6848344802856445


  1%|▉                                                                             | 29/2420 [01:31<2:31:22,  3.80s/it]

Epoch [1/3], Batch [29/2420], Loss: 0.9493051767349243


  1%|▉                                                                             | 30/2420 [01:35<2:27:49,  3.71s/it]

Epoch [1/3], Batch [30/2420], Loss: 0.7273212671279907


  1%|▉                                                                             | 31/2420 [01:39<2:30:16,  3.77s/it]

Epoch [1/3], Batch [31/2420], Loss: 0.8027989864349365


  1%|▉                                                                             | 31/2420 [01:41<2:10:02,  3.27s/it]


KeyboardInterrupt: 