<a href="https://colab.research.google.com/github/alpharomeo7/TLA-TCP/blob/master/TCP_simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install z3-solver



In [2]:
import random as r
import random
import string

def generate_random_string(length):
  """Return a random string of `length` ASCII letters.

  Each character is drawn independently with `random.choice` from the
  lowercase letters followed by the uppercase letters. A `length` of zero
  (or less) yields the empty string.
  """
  # string.ascii_letters is exactly ascii_lowercase + ascii_uppercase.
  alphabet = string.ascii_letters
  picked = [random.choice(alphabet) for _ in range(length)]
  return ''.join(picked)


class TCPPacket:
  """A single packet exchanged between two simulated endpoints.

  Attributes:
    sender:  name of the endpoint that built the packet.
    payload: string carried by the packet (may be empty).
    sq_num:  sender's sequence number when the packet was built.
    ack:     sender's acknowledgement number when the packet was built.
    flags:   list of flag strings such as "SYN", "ACK" or "FIN"; the list
             object passed in is stored as-is, not copied.
    dropped: False on construction; set to True by the sender when the
             packet is lost instead of delivered.
  """

  def __init__(self, sender, payload, sq_num, ack, flags):
    # Addressing / control information.
    self.sender = sender
    self.flags = flags
    self.sq_num = sq_num
    self.ack = ack
    # Data carried by the packet.
    self.payload = payload
    # Delivery status; packets start out as not dropped.
    self.dropped = False

class Server:
  """One endpoint of a simplified, simulated TCP connection.

  An endpoint performs a SYN / SYN-ACK handshake with its peer (`other`),
  streams its randomly generated `payload` in chunks of 10-50 characters,
  and emits FIN once every character of the payload has been acknowledged.
  Delivery is modelled by writing a packet directly into the peer's `recv`
  slot; packets without the SYN flag are dropped with probability 0.1.
  """

  def __init__(self, name):
    """Create an unconnected endpoint called `name`.

    `other` must be set to the peer endpoint before packets are sent.
    """
    self.name = name
    self.sq_num = 0           # current sequence number
    self.ack = 0              # next sequence number expected from the peer
    self.sent = []            # every packet built by send_packet, in order
    self.recv = None          # single-slot inbox written by the peer
    self.other = None         # peer endpoint
    self.started = False      # True once the first SYN has been built
    self.finished = False     # True once a FIN has been built
    self.closed = False       # True once a FIN arrived after we finished
    self.connected = False    # True once the handshake is complete
    self.bytes_sent = 0       # payload characters the peer has accepted
    self.bytes_recv = 0       # payload characters accepted from the peer
    self.init_sqnum = r.randint(500, 2000)
    self.init_ack = 0
    self.recv_payload = ""    # payload accepted from the peer so far
    # randint(200, 200) always yields 200, so the payload has 200 characters.
    self.payload = generate_random_string(r.randint(200, 200))

  def send_packet(self, msg_q):
    """Build the next packet, record it, and try to deliver it to the peer.

    The packet is appended to `self.sent` and to `msg_q` whether or not it
    is delivered. Nothing happens once the endpoint is closed.

    Args:
      msg_q: list acting as a global log of every packet put on the wire.
    """
    packet = self.make_packet()
    if packet is None:
      return
    self.sent.append(packet)
    # SYN packets are never lost; any other packet is dropped 10% of the time.
    if "SYN" not in packet.flags and r.random() < 0.1:
      packet.dropped = True
    else:
      self.other.recv = packet
    msg_q.append(packet)

  def recv_packet(self):
    """Process the packet waiting in `self.recv`, if there is one.

    A packet that matches none of the cases below is left in `self.recv`
    untouched.
    """
    packet = self.recv
    if packet is None:
      return

    # Peer's FIN after we have finished ourselves: the connection is closed.
    if self.finished and "FIN" in packet.flags:
      self.closed = True
      return

    # Handshake: a SYN consumes one sequence number of the peer.
    if "SYN" in packet.flags:
      self.ack = packet.sq_num + 1
      if "ACK" in packet.flags:
        self.connected = self.started
        self.sq_num = packet.ack
      self.recv = None
      return

    # We are done sending, so the peer's ack can no longer advance; accept
    # the packet only if it carries exactly the data we expect next.
    if self.finished and self.ack == packet.sq_num:
      self._accept_payload(packet)
      self.recv = None
      return

    # Normal case: the peer acknowledged new data of ours.
    if packet.ack > self.sq_num:
      self.connected = self.started
      self.sq_num = packet.ack
      self._accept_payload(packet)
      self.recv = None

  def _accept_payload(self, packet):
    """Append `packet`'s payload and update ack and both byte counters."""
    size = len(packet.payload)
    self.recv_payload += packet.payload
    self.ack += size
    # Keep the counters in step with the payload on every accepting path,
    # including data that arrives after this endpoint has finished sending.
    self.bytes_recv += size
    self.other.bytes_sent += size

  def make_packet(self):
    """Return the next packet to send, or None when the endpoint is closed.

    Returns:
      A SYN (or SYN+ACK) packet while not connected, a FIN+ACK packet once
      the whole payload has been acknowledged, otherwise an ACK packet
      carrying the next chunk of the payload.
    """
    if self.closed:
      return None

    # Drawn on every call, even when the packet built carries no payload.
    nbytes = r.randint(10, 50)

    if not self.connected:
      self.started = True
      flags = ["SYN"]
      # A non-zero ack means the peer's SYN was already seen: answer SYN+ACK.
      if self.ack > 0:
        flags.append("ACK")
      self.sq_num = self.init_sqnum
      return TCPPacket(self.name, "", self.sq_num, self.ack, flags)

    # The "+ 1" is the sequence number consumed by our own SYN.
    if self.sq_num == self.init_sqnum + len(self.payload) + 1:
      self.finished = True
      return TCPPacket(self.name, "", self.sq_num, self.ack, ["FIN", "ACK"])

    # Offset of the first payload character not yet acknowledged.
    start = self.sq_num - self.init_sqnum - 1
    chunk = self.payload[start:start + nbytes]
    return TCPPacket(self.name, chunk, self.sq_num, self.ack, ["ACK"])






In [3]:
from z3 import *
def get_solver():
  """Build a Z3 solver that is satisfiable only if an invariant is violated.

  The solver holds a single assertion: the disjunction of the negations of
  four invariants over the state of servers S1 and S2:

    S1SQN <= S2ACK,  S2SQN <= S1ACK,  S1SENT == S2RECV,  S2SENT == S1RECV

  Once the caller pins every variable to a concrete value, a result of
  `unsat` from `check()` means all four invariants hold.

  Returns:
    A `(solver, variables)` pair, where `variables` maps each variable name
    to its Z3 `Int`.
  """
  names = ('S1SQN', 'S1ACK', 'S2SQN', 'S2ACK',
           'S1SENT', 'S2SENT', 'S1RECV', 'S2RECV')
  symbols = {name: Int(name) for name in names}

  # Each entry is the negation of one invariant.
  violations = [
      Not(symbols['S1SQN'] <= symbols['S2ACK']),
      Not(symbols['S2SQN'] <= symbols['S1ACK']),
      Not(symbols['S1SENT'] == symbols['S2RECV']),
      Not(symbols['S2SENT'] == symbols['S1RECV']),
  ]

  solver = Solver()
  solver.add(Or(violations))
  return solver, symbols

In [4]:
# Wire two servers together and let them talk for at most 2000 steps.
msg_q = []
s1, s2 = Server("s1"), Server("s2")
s1.other, s2.other = s2, s1

for i in range(2000):
  # Each step one randomly chosen side sends and the other side receives.
  x = r.random()
  sender, receiver = (s1, s2) if x < 0.5 else (s2, s1)
  sender.send_packet(msg_q)
  receiver.recv_packet()

  # Pin the solver variables to the current state of both servers.
  s, vars = get_solver()
  observed = {
      'S1SQN': s1.sq_num,
      'S2SQN': s2.sq_num,
      'S1ACK': s1.ack,
      'S2ACK': s2.ack,
      'S1SENT': s1.bytes_sent,
      'S2SENT': s2.bytes_sent,
      'S1RECV': s1.bytes_recv,
      'S2RECV': s2.bytes_recv,
  }
  for key, value in observed.items():
    s.add(vars[key] == value)

  # Anything other than unsat means some invariant may be violated.
  if s.check() != unsat:
    print("TCP is wrong?")

  # Stop as soon as both servers are done.
  if s1.finished and s2.finished:
    break


In [11]:
# Check that both servers are marked finished, i.e. each one has had its
# whole payload acknowledged and has built its FIN packet.
s1.finished and s2.finished

True

In [6]:
# Check that the payload each server sent is exactly the payload the other
# server received.
s1.payload == s2.recv_payload and s2.payload == s1.recv_payload

True