In [19]:
import torch
import torch.nn as nn
import torch.optim as optim

# 1. Preprocess the data (tokenize into 6-mers and create context pairs)
seqs = ['TCGTTG', 'CGTTGA', 'GTTGAG', 'TTGAGA', 'TGAGAT', 'GAGATG', 'AGATGG', 'GATGGG', 'ATGGGT', 'TGGGTC']
window_size = 2  # context size for skip-gram

In [20]:
def create_skipgram_pairs(seqs, window_size):
    pairs = []
    for seq in seqs:
        for i, word in enumerate(seq):
            # Define context window
            start = max(0, i - window_size)
            end = min(len(seq), i + window_size + 1)
            for j in range(start, end):
                if j != i:
                    pairs.append((word, seq[j]))  # (target, context)
    return pairs

pairs = create_skipgram_pairs(seqs, window_size)

In [22]:
vocab = list(set([c for seq in seqs for c in seq]))  # Vocabulary of characters (6-mers)
word2idx = {word: i for i, word in enumerate(vocab)}  # Map word to index
idx2word = {i: word for word, i in word2idx.items()}  # Map index to word

# 4. Create dataset of target and context pairs
target_words = [word2idx[word[0]] for word in pairs]  # Target words (6-mers)
context_words = [word2idx[word[1]] for word in pairs]  # Context words (characters)

In [21]:
class SkipGramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super(SkipGramModel, self).__init__()
        self.in_embedding = nn.Embedding(vocab_size, embedding_dim)
        self.out_embedding = nn.Embedding(vocab_size, embedding_dim)
    
    def forward(self, target):
        target_emb = self.in_embedding(target)  # Embedding of target
        context_emb = self.out_embedding.weight  # Context embeddings (output layer is shared)
        score = torch.matmul(target_emb, context_emb.t())  # Dot product with context embeddings
        return score

In [23]:
embedding_dim = 10
model = SkipGramModel(len(vocab), embedding_dim)

# 6. Set up optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()  # This will be used to calculate loss over the context

In [24]:
epochs = 100
for epoch in range(epochs):
    total_loss = 0
    for target, context in zip(target_words, context_words):
        optimizer.zero_grad()
        
        # Convert target and context to tensors of appropriate shape
        target_tensor = torch.tensor([target])  # Shape (1,)
        context_tensor = torch.tensor([context])  # Shape (1,)
        
        # Forward pass
        score = model(target_tensor)  # Get score (logits) for context words
        loss = criterion(score, context_tensor)  # Loss against context word
        loss.backward()  # Backpropagate
        
        # Update weights
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss}")

Epoch 1/100, Loss: 285.3716446682811
Epoch 2/100, Loss: 224.45380972325802
Epoch 3/100, Loss: 211.65820042788982
Epoch 4/100, Loss: 206.90768520534039
Epoch 5/100, Loss: 205.70158591866493
Epoch 6/100, Loss: 205.92478255927563
Epoch 7/100, Loss: 206.45002116262913
Epoch 8/100, Loss: 206.81446540355682
Epoch 9/100, Loss: 206.9582861661911
Epoch 10/100, Loss: 206.96845388412476
Epoch 11/100, Loss: 206.92398042976856
Epoch 12/100, Loss: 206.86333034932613
Epoch 13/100, Loss: 206.80056509375572
Epoch 14/100, Loss: 206.7402051538229
Epoch 15/100, Loss: 206.683678150177
Epoch 16/100, Loss: 206.6312806904316
Epoch 17/100, Loss: 206.58279745280743
Epoch 18/100, Loss: 206.53781494498253
Epoch 19/100, Loss: 206.49590329825878
Epoch 20/100, Loss: 206.45674999058247
Epoch 21/100, Loss: 206.4201611727476
Epoch 22/100, Loss: 206.38606256246567
Epoch 23/100, Loss: 206.35443659126759
Epoch 24/100, Loss: 206.32528126239777
Epoch 25/100, Loss: 206.29859128594398
Epoch 26/100, Loss: 206.2743180990219
Epo

In [25]:
print("\nTrained Embeddings (for target words):")
for i in range(len(vocab)):
    print(f"{idx2word[i]}: {model.in_embedding.weight.data[i]}")


Trained Embeddings (for target words):
T: tensor([-0.3007, -0.3713,  0.1541, -0.0021, -0.3340,  0.3385,  0.3076, -0.8643,
         0.3777, -0.0364])
G: tensor([-0.0101,  0.5077,  0.2132,  0.2711, -0.1591,  0.5186,  0.0081, -0.3631,
        -0.1978,  0.1125])
C: tensor([-0.1077,  1.0323,  0.2655,  0.5090, -0.6059,  0.7848, -0.0206, -0.7407,
        -0.4208, -0.0199])
A: tensor([-0.1957,  0.0563,  0.2357,  0.1602, -0.2924,  0.5408,  0.2069, -0.7741,
         0.1306,  0.0580])
