In [1]:
import torch
import torch.nn as nn

In [2]:
class CBOW(torch.nn.Module):
    """Continuous bag-of-words model over literals.

    The embeddings of all context literals are summed into a single vector,
    passed through a ReLU hidden layer, and projected to log-probabilities
    over the whole vocabulary.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            output classes.
        embedding_dim: Width of each embedding vector.
        literal_to_ix: Mapping from a literal to its row index in the
            embedding table.
        hidden_dim: Width of the hidden layer (defaults to 128, the value
            that used to be hard-coded).
    """

    def __init__(self, vocab_size, embedding_dim, literal_to_ix, hidden_dim=128):
        super(CBOW, self).__init__()
        # out: 1 x embedding_dim
        self.vocab_size = vocab_size
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.literal_to_ix = literal_to_ix
        self.linear1 = nn.Linear(embedding_dim, hidden_dim)
        self.activation_function1 = nn.ReLU()

        # out: 1 x vocab_size
        self.linear2 = nn.Linear(hidden_dim, vocab_size)
        self.activation_function2 = nn.LogSoftmax(dim=-1)

    def forward(self, inputs):
        """Return log-probabilities (shape 1 x vocab_size) for a context.

        Args:
            inputs: Tensor of embedding row indices for the context literals.
                May be empty, in which case the summed context is all zeros.
        """
        # Tensor.sum(dim=0) rather than the builtin sum(): the builtin yields
        # the plain int 0 for an empty context, which has no .view() method.
        embeds = self.embeddings(inputs).sum(dim=0).view(1, -1)
        out = self.linear1(embeds)
        out = self.activation_function1(out)
        out = self.linear2(out)
        out = self.activation_function2(out)
        return out

    def get_literal_embedding(self, literal):
        """Return the embedding (shape 1 x embedding_dim) of one literal.

        Raises:
            KeyError: If ``literal`` is not a key of ``literal_to_ix``.
        """
        # Build the index on the same device as the table so the lookup keeps
        # working after the module has been moved with .to(device).
        ix = torch.tensor([self.literal_to_ix[literal]],
                          device=self.embeddings.weight.device)
        return self.embeddings(ix)

    def get_embeddings(self):
        """Return all embeddings, one row per vocabulary index."""
        ix = torch.arange(self.vocab_size, device=self.embeddings.weight.device)
        return self.embeddings(ix)

In [3]:
# Utilities: literal-to-index conversion and SAT (.cnf) file parsing

def make_context_vector(context, literal_to_idx):
    """Translate a sequence of literals into a long tensor of indices.

    Args:
        context: Iterable of literals, each a key of ``literal_to_idx``.
        literal_to_idx: Mapping from literal to integer index.

    Returns:
        A 1-D ``torch.long`` tensor holding the index of each literal, in
        the order the literals appear in ``context``.
    """
    indices = []
    for literal in context:
        indices.append(literal_to_idx[literal])
    return torch.tensor(indices, dtype=torch.long)


def read_sat(sat_path):
    """Parse a CNF formula from a DIMACS-style text file.

    The first meaningful line is the header; its last two
    whitespace-separated tokens are read as the variable count and the
    clause count. Every following line is one clause: integer literals,
    optionally terminated by a ``0`` token.

    Blank lines and comment lines (starting with ``c``) are skipped, any run
    of whitespace separates tokens, and the final line does not need a
    trailing newline.

    Args:
        sat_path: Path of the file to read.

    Returns:
        A tuple ``(sat, num_vars, num_clauses)`` where ``sat`` is a list of
        clauses and each clause is a list of int literals.

    Raises:
        ValueError: If the file has no header line or a token is not an
            integer.
    """
    with open(sat_path) as f:
        stripped = (raw.strip() for raw in f)
        lines = [ln for ln in stripped if ln and not ln.startswith("c")]

    if not lines:
        raise ValueError(f"no header line found in {sat_path!r}")

    header_info = lines[0].split()
    num_vars = int(header_info[-2])
    num_clauses = int(header_info[-1])

    sat = []
    for line in lines[1:]:
        tokens = line.split()
        # Drop the clause terminator explicitly instead of relying on the
        # exact byte sequence " 0\n" being present at the end of the line.
        if tokens[-1] == "0":
            tokens = tokens[:-1]
        sat.append([int(tok) for tok in tokens])

    return sat, num_vars, num_clauses

In [4]:
# data preprocessing

sat_path = './ssa2670-141.processed.cnf'
sat_instance, num_vars, num_clauses = read_sat(sat_path)
# Two vocabulary slots per variable.
vocab_size = 2 * num_vars

# For every position in every clause, emit one training pair: the literal at
# that position is the target and the rest of the clause is its context.
data = []
for clause in sat_instance:
    for pos, target in enumerate(clause):
        context = clause[:pos] + clause[pos + 1:]
        data.append((context, target))

print(f'data size: {len(data)}')

data size: 1161


In [5]:
# model setting

SEED = 0
EMBEDDING_DIM = 50
# Backward-compatible alias for the original, misspelled constant name.
EMDEDDING_DIM = EMBEDDING_DIM

# Variable i maps to row 2i - 2 and its negation -i to row 2i - 1, so the two
# polarities of a variable occupy adjacent rows of the embedding table.
literal_to_ix = {}
for i in range(1, num_vars + 1):
    literal_to_ix[i] = 2 * i - 2
    literal_to_ix[-i] = 2 * i - 1

# Seed before the model is built so parameter initialisation is reproducible.
torch.manual_seed(SEED)
model = CBOW(vocab_size, EMBEDDING_DIM, literal_to_ix)
loss_function = nn.NLLLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# The index tensors never change between epochs, so build them once instead of
# re-creating them for every sample in every epoch.
training_pairs = [
    (make_context_vector(context, literal_to_ix),
     torch.tensor([literal_to_ix[target]]))
    for context, target in data
]

# training
# One full-batch update per epoch: the per-sample losses are summed and a
# single optimizer step is taken on the total.
for epoch in range(100):
    total_loss = 0
    for context_vector, target_ix in training_pairs:
        log_probs = model(context_vector)
        total_loss += loss_function(log_probs, target_ix)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(epoch)

0
10
20
30
40
50
60
70
80
90


In [7]:
# test the embedding

# Look up a single literal through the model's literal-to-index mapping, then
# fetch the whole embedding table so individual rows can be inspected.
literal_embedding = model.get_literal_embedding(91)
print(literal_embedding)
embeddings = model.get_embeddings()
# torch.save(embeddings, './embeddings.pt')

tensor([[-0.4813, -0.4265, -0.0830,  0.2998,  1.7137, -0.0975, -1.7145,  0.0711,
          0.2901,  0.5749, -0.0985, -1.5374,  1.3501,  0.7128, -1.2989, -1.8904,
         -0.5634,  2.5485,  0.3694, -0.2581,  0.7689, -1.1553, -0.5662, -1.3081,
          0.3625, -0.8950, -0.5461,  0.3160,  0.2465, -0.0082, -1.3402, -1.2443,
         -0.9131, -0.6275, -0.4398, -0.3806,  0.5567, -1.3550,  0.1531, -1.0460,
          0.6436,  0.4479, -1.2855, -0.5104,  0.5272, -0.2633,  2.7738, -0.5985,
          0.9515, -1.5722]], grad_fn=<EmbeddingBackward0>)


In [8]:
# Row of literal 91 (literal_to_ix[91] == 2 * 91 - 2 == 180); it should equal
# the vector returned by model.get_literal_embedding(91).
embeddings[literal_to_ix[91]]

tensor([-0.4813, -0.4265, -0.0830,  0.2998,  1.7137, -0.0975, -1.7145,  0.0711,
         0.2901,  0.5749, -0.0985, -1.5374,  1.3501,  0.7128, -1.2989, -1.8904,
        -0.5634,  2.5485,  0.3694, -0.2581,  0.7689, -1.1553, -0.5662, -1.3081,
         0.3625, -0.8950, -0.5461,  0.3160,  0.2465, -0.0082, -1.3402, -1.2443,
        -0.9131, -0.6275, -0.4398, -0.3806,  0.5567, -1.3550,  0.1531, -1.0460,
         0.6436,  0.4479, -1.2855, -0.5104,  0.5272, -0.2633,  2.7738, -0.5985,
         0.9515, -1.5722], grad_fn=<SelectBackward0>)