# In [12]:
import unittest
import torch
import math

# assuming the full model code has already been written or imported above
class TestBigramLanguageModel(unittest.TestCase):
    """Smoke tests for a pre-trained ``BigramLanguageModel``.

    The tests rely on names that are defined outside this class and must
    already be in scope when it runs: ``BigramLanguageModel``, ``get_batch``,
    ``estimate_loss``, ``train_test_split``, ``batch_size`` and
    ``block_size``.
    """

    # Number of rows expected in the checkpoint's token embedding table.
    expected_vocab_size = 77

    @classmethod
    def setUpClass(cls):
        """Load the checkpoint once and prepare a small tokenized sample."""
        # The module-level ``device`` is rebound on purpose so that every
        # later use of that global in this session runs on the CPU.
        global device
        device = torch.device('cpu')  # Force CPU usage for testing

        # Load the pre-trained model
        cls.model = BigramLanguageModel()
        model_path = '/kaggle/input/lyrics-generator-by-spotify-million-song-dataset/pytorch/default/1/lyric_generator_model.pth'
        cls.model.load_state_dict(torch.load(model_path, map_location=device))
        cls.model = cls.model.to(device)
        # Inference mode: switches off train-only layers (e.g. dropout, if
        # the model has any) so repeated forward passes are comparable.
        cls.model.eval()

        # Sample text for testing
        cls.text = "This is a sample text for testing purposes."

        # Character-level tokenization of the sample text
        cls.data, cls.stoi, cls.itos, cls.vocab_size, cls.encode, cls.decode = cls.text_tokenize(cls.text, cls.model)

        # Split data
        cls.train_data, cls.val_data = train_test_split(cls.data, 0.8)

    @staticmethod
    def text_tokenize(text, model):
        """Build a character-level vocabulary from ``text`` and encode it.

        Args:
            text: String whose distinct characters form the vocabulary.
            model: Accepted for interface compatibility; not used.

        Returns:
            A tuple ``(data, stoi, itos, vocab_size, encode, decode)``:
            ``data`` is a ``torch.long`` tensor with the ids of ``text``,
            ``stoi``/``itos`` map characters to ids and back, ``vocab_size``
            is ``len(stoi)``, and ``encode``/``decode`` convert between
            strings and id lists, silently skipping unknown symbols.
        """
        # NOTE(review): the vocabulary comes from ``text`` alone, so these
        # ids presumably do not line up with the vocabulary the checkpoint
        # was trained on -- confirm before feeding ``data`` to the model.
        chars = sorted(set(text))
        itos = dict(enumerate(chars))
        stoi = {ch: i for i, ch in itos.items()}

        def encode(s):
            return [stoi[c] for c in s if c in stoi]

        def decode(ids):
            return ''.join(itos[i] for i in ids if i in itos)

        data = torch.tensor(encode(text), dtype=torch.long)

        return data, stoi, itos, len(stoi), encode, decode

    def test_model_initialization(self):
        """Test if the model initializes correctly."""
        self.assertIsInstance(self.model, BigramLanguageModel)
        self.assertEqual(
            self.model.token_embedding_table.num_embeddings,
            self.expected_vocab_size,
        )

    def test_forward_pass(self):
        """Test the forward pass of the model."""
        x, y = get_batch('train')
        with torch.no_grad():  # no gradients are needed for a shape check
            logits, loss = self.model(x, y)
        expected_shape = (
            batch_size * block_size,
            self.model.token_embedding_table.num_embeddings,
        )
        self.assertEqual(logits.shape, expected_shape)
        self.assertIsInstance(loss.item(), float)

    def test_generate(self):
        """Test the text generation functionality."""
        context = torch.zeros((1, 1), dtype=torch.long, device=device)
        generated = self.model.generate(context, max_new_tokens=10)
        # One starting token plus the ten newly generated ones.
        self.assertEqual(generated.shape, (1, 11))

    def test_save_load_model(self):
        """Test saving and loading the model."""
        import os
        import tempfile

        loaded_model = BigramLanguageModel()
        # Write the checkpoint inside a temporary directory so the test
        # leaves no file behind in the working directory.
        with tempfile.TemporaryDirectory() as tmp_dir:
            model_path = os.path.join(tmp_dir, 'test_model.pth')
            torch.save(self.model.state_dict(), model_path)
            loaded_model.load_state_dict(torch.load(model_path, map_location=device))
        loaded_model = loaded_model.to(device)

        # Both models must be in eval mode; otherwise train-only randomness
        # could make identical weights produce different outputs.
        self.model.eval()
        loaded_model.eval()

        x, y = get_batch('train')
        with torch.no_grad():
            logits_before, loss_before = self.model(x, y)
            logits_after, loss_after = loaded_model(x, y)
        self.assertTrue(torch.allclose(logits_before, logits_after))
        self.assertAlmostEqual(loss_before.item(), loss_after.item())

    def test_estimate_loss(self):
        """Test the loss estimation function."""
        losses = estimate_loss()
        for split in ('train', 'val'):
            self.assertIn(split, losses)
            self.assertIsInstance(losses[split].item(), float)

# Run the tests in-process. argv=[''] supplies only a placeholder program
# name, so unittest does not try to parse the host process's command-line
# arguments; exit=False keeps unittest from calling sys.exit() when the run
# finishes, so the call can be used from an interactive session or notebook.
unittest.main(argv=[''], verbosity=2, exit=False)

# Notebook cell output:
# ModuleNotFoundError: No module named 'bigram_model'